From 37042d5b6782baaab926a194d6f5b48dd2905df4 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 16 Apr 2020 08:08:07 -0400 Subject: [PATCH 01/10] add global batch normalization --- mindspore/nn/layer/__init__.py | 4 +- mindspore/nn/layer/normalization.py | 130 ++++++++++++++++++++++++--- tests/ut/python/nn/test_batchnorm.py | 15 ++++ 3 files changed, 133 insertions(+), 16 deletions(-) diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py index cf601f03ff..714b517a84 100644 --- a/mindspore/nn/layer/__init__.py +++ b/mindspore/nn/layer/__init__.py @@ -18,7 +18,7 @@ Layer. The high-level components(Cells) used to construct the neural network. """ from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish -from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm +from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, GlobalBatchNorm from .container import SequentialCell, CellList from .conv import Conv2d, Conv2dTranspose from .lstm import LSTM @@ -29,7 +29,7 @@ from .image import ImageGradients, SSIM __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', 'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU', - 'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', + 'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'SequentialCell', 'CellList', 'Conv2d', 'Conv2dTranspose', 'LSTM', diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 4aafaf031e..4bfa222986 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -20,16 +20,28 @@ from mindspore.common.initializer import initializer from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype import mindspore.context as context -from mindspore._checkparam import check_int_positive, check_bool, check_typename +from mindspore._checkparam import check_bool, check_typename from mindspore._extends import cell_attr_register +from mindspore.communication.management import get_local_rank_size, get_rank +from mindspore.communication import management +from mindspore._checkparam import check_int_positive from ..cell import Cell +class _GlobalBNHelper(Cell): + def __init__(self, group): + super(_GlobalBNHelper, self).__init__() + self.group = group + self.reduce = P.AllReduce(P.ReduceOp.SUM, group=self.group).add_prim_attr('fusion', 1) + def construct(self, x): + x = self.reduce(x) + return x class _BatchNorm(Cell): """Batch Normalization base class.""" @cell_attr_register def __init__(self, num_features, + group=1, eps=1e-5, momentum=0.9, affine=True, @@ -56,6 +68,20 @@ class _BatchNorm(Cell): gamma_init, num_features), name="gamma", requires_grad=affine) self.beta = Parameter(initializer( beta_init, num_features), name="beta", requires_grad=affine) + self.group = check_int_positive(group) + self.rank_id = get_rank() + self.rank_size = get_local_rank_size() + self.device_list = [i for i in range(0, self.rank_size)] + self.rank_list = self.list_group(self.device_list, self.group) + self.rank_list_idx = len(self.rank_list) + for i in range(self.rank_list_idx): + if self.rank_id in self.rank_list[i] and self.group != 1: + self.is_global = True + management.create_group('group' + str(i), self.rank_list[i]) + self.all_reduce = _GlobalBNHelper('group' + str(i)) + self.shape = P.Shape() + self.reduce_mean = P.ReduceMean() + self.square = P.Square() if 
context.get_context("enable_ge"):
             self.is_ge_backend = True
@@ -82,22 +108,52 @@ class _BatchNorm(Cell):
     def _check_data_dim(self, x):
         raise NotImplementedError
 
+    def list_group(self, world_rank, group_size):
+        if group_size > get_local_rank_size():
+            raise ValueError("group size can not be greater than local rank size, group size is {}, local_rank_size is {}".format(group_size, get_local_rank_size()))
+        if len(world_rank) % group_size != 0:
+            raise ValueError("please make your group size correct.")
+        world_rank_list = zip(*(iter(world_rank),) *group_size)
+        group_list = [list(i) for i in world_rank_list]
+        return group_list
+
     def construct(self, x):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
-                y, batch_mean, batch_var, _, _ = \
-                    self.bn_train(x,
-                                  self.gamma,
-                                  self.beta,
-                                  None,
-                                  None)
-
-                mean_sub = self.sub_mean(self.moving_mean, batch_mean)
-                temp_mean = self.mul_mean(mean_sub, self.momentum)
-                mean_sub2 = self.sub_var(self.moving_variance, batch_var)
-                temp_variance = self.mul_var(mean_sub2, self.momentum)
-                y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
-                y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+                if self.is_global:
+                    x_mean = self.reduce_mean(x)
+                    x_mean_square = self.reduce_mean(self.square(x))
+                    global_batch_mean = self.all_reduce(x_mean) / self.group
+                    global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
+                    global_mean = global_batch_mean
+                    global_var = global_batch_mean_square - self.square(global_batch_mean)
+                    y, batch_mean, batch_var, _, _ = \
+                        self.bn_train(x,
+                                      self.gamma,
+                                      self.beta,
+                                      None,
+                                      None)
+
+                    mean_sub = self.sub_mean(self.moving_mean, global_mean)
+                    temp_mean = self.mul_mean(mean_sub, self.momentum)
+                    mean_sub2 = self.sub_var(self.moving_variance, global_var)
+                    temp_variance = self.mul_var(mean_sub2, self.momentum)
+                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
+                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+                else:
+                    y, batch_mean, batch_var, _, _ = \
+                        self.bn_train(x,
+                                      self.gamma,
+                                      self.beta,
+                                      None,
+                                      None)
+
+                    mean_sub = self.sub_mean(self.moving_mean, batch_mean)
+                    temp_mean = self.mul_mean(mean_sub, self.momentum)
+                    mean_sub2 = self.sub_var(self.moving_variance, batch_var)
+                    temp_variance = self.mul_var(mean_sub2, self.momentum)
+                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
+                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
             else:
                 y = self.bn_train(x,
                                   self.gamma,
@@ -221,6 +277,52 @@ class BatchNorm2d(_BatchNorm):
         pass
 
 
+class GlobalBatchNorm(_BatchNorm):
+    r"""
+    Global Batch Normalization layer over an N-dimensional input.
+
+    Global Batch Normalization is cross-device synchronized Batch Normalization. Standard Batch Normalization
+    only normalizes the data within each device, while Global Batch Normalization normalizes the input across the whole group.
+    It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
+    Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
+    features using a mini-batch of data and the learned parameters, which can be described by the following formula.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Args:
+        num_features (int): `C` from an expected input of size (N, C, H, W).
+        group (int): The number of devices in each group.
+        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
+        momentum (float): A floating-point hyperparameter of the momentum for the
+            running_mean and running_var computation. Default: 0.9.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch of data; otherwise,
+            use the stored moving mean and moving variance. Default: True.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+
+    Examples:
+        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, group=4)
+        >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
+        >>> global_bn_op(input)
+    """
+
 class LayerNorm(Cell):
     r"""
     Applies Layer Normalization over a mini-batch of inputs.
diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py
index e73b7ebbf0..23ca79e8e1 100644
--- a/tests/ut/python/nn/test_batchnorm.py
+++ b/tests/ut/python/nn/test_batchnorm.py
@@ -19,6 +19,7 @@ import pytest
 import mindspore.nn as nn
 from mindspore.common.api import _executor
 from mindspore import Tensor, Parameter
+from mindspore.communication.management import init
 
 
 def test_bn_pars_valid1():
@@ -70,3 +71,17 @@ def test_compile_groupnorm():
     net = nn.GroupNorm(16, 64)
     input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))
     _executor.compile(net, input_data)
+
+class GlobalBNNet(nn.Cell):
+    def __init__(self):
+        super(GlobalBNNet, self).__init__()
+        self.bn = nn.GlobalBatchNorm(num_features = 2, group = 4)
+    def construct(self, x):
+        return self.bn(x)
+
+def test_gloabl_bn():
+    init("hccl")
+    net = GlobalBNNet()
+    input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32))
+    net.set_train()
+    out = net(input_data)

From 27c307684901a9ef7ee1ee8327191f59e1012748 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Thu, 16 Apr 2020 10:00:35 -0400
Subject: [PATCH 02/10] add global batch normalization

---
 mindspore/nn/layer/normalization.py  | 27 ++++++++++++++++-----------
 tests/ut/python/nn/test_batchnorm.py | 11 +++++++++--
 2 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 4bfa222986..32e4be5998 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -69,16 +69,17 @@ class _BatchNorm(Cell):
         self.beta = Parameter(initializer(
             beta_init, num_features), name="beta", requires_grad=affine)
         self.group = check_int_positive(group)
-        self.rank_id = get_rank()
-        self.rank_size = 
get_local_rank_size() - self.device_list = [i for i in range(0, self.rank_size)] - self.rank_list = self.list_group(self.device_list, self.group) - self.rank_list_idx = len(self.rank_list) - for i in range(self.rank_list_idx): - if self.rank_id in self.rank_list[i] and self.group != 1: - self.is_global = True - management.create_group('group' + str(i), self.rank_list[i]) - self.all_reduce = _GlobalBNHelper('group' + str(i)) + if self.group != 1: + self.rank_id = get_rank() + self.rank_size = get_local_rank_size() + self.device_list = [i for i in range(0, self.rank_size)] + self.rank_list = self.list_group(self.device_list, self.group) + self.rank_list_idx = len(self.rank_list) + for i in range(self.rank_list_idx): + if self.rank_id in self.rank_list[i] and self.group != 1: + self.is_global = True + management.create_group('group' + str(i), self.rank_list[i]) + self.all_reduce = _GlobalBNHelper('group' + str(i)) self.shape = P.Shape() self.reduce_mean = P.ReduceMean() self.square = P.Square() @@ -110,7 +111,8 @@ class _BatchNorm(Cell): def list_group(self, world_rank, group_size): if group_size > get_local_rank_size(): - raise ValueError("group size can not be greater than local rank size, group size is {}, local_rank_size is {}".format(group_size, get_local_rank_size())) + raise ValueError("group size can not be greater than local rank size, group size is {}, local_rank_size is {}".format( + group_size, get_local_rank_size())) if len(world_rank) % group_size != 0: raise ValueError("please make your group size correct.") world_rank_list = zip(*(iter(world_rank),) *group_size) @@ -322,6 +324,9 @@ class GlobalBatchNorm(_BatchNorm): >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) >>> global_bn_op(input) """ + def _check_data_dim(self, x): + if x.dim == 0: + pass class LayerNorm(Cell): r""" diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py index 23ca79e8e1..24f0de85f7 100644 --- a/tests/ut/python/nn/test_batchnorm.py +++ b/tests/ut/python/nn/test_batchnorm.py @@ -20,6 +20,8 @@ import mindspore.nn as nn from mindspore.common.api import _executor from mindspore import Tensor, Parameter from mindspore.communication.management import init +from mindspore import context +from mindspore import ParallelMode def test_bn_pars_valid1(): @@ -75,12 +77,17 @@ def test_compile_groupnorm(): class GlobalBNNet(nn.Cell): def __init__(self): super(GlobalBNNet, self).__init__() - self.bn = nn.GlobalBatchNorm(num_features = 2, group = 4) + self.bn = nn.GlobalBatchNorm(num_features = 2, group = 2) def construct(self, x): return self.bn(x) -def test_gloabl_bn(): +def test_global_bn(): init("hccl") + size = 4 + context.set_context(mode=context.GRAPH_MODE) + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, + device_num=size, parameter_broadcast=True) net = GlobalBNNet() input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32)) net.set_train() From b5e98042c531650c639c9975b9709c384cfde933 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 16 Apr 2020 10:06:21 -0400 Subject: [PATCH 03/10] add global batch normalization --- mindspore/nn/layer/normalization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 32e4be5998..dddf32ec48 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -111,8 +111,8 @@ class _BatchNorm(Cell): def 
list_group(self, world_rank, group_size): if group_size > get_local_rank_size(): - raise ValueError("group size can not be greater than local rank size, group size is {}, local_rank_size is {}".format( - group_size, get_local_rank_size())) + raise ValueError("group size can not be greater than local rank size, group size is {}, " + "local_rank_size is {}".format(group_size, get_local_rank_size())) if len(world_rank) % group_size != 0: raise ValueError("please make your group size correct.") world_rank_list = zip(*(iter(world_rank),) *group_size) From 616b9ea394d4d90b8caf6e5fc2453eb822ead721 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 16 Apr 2020 10:34:41 -0400 Subject: [PATCH 04/10] add global batch normalization --- mindspore/nn/layer/normalization.py | 4 ++-- tests/ut/python/hccl_test/manage/api.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index dddf32ec48..2b55147cf1 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype import mindspore.context as context from mindspore._checkparam import check_bool, check_typename from mindspore._extends import cell_attr_register -from mindspore.communication.management import get_local_rank_size, get_rank +from mindspore.communication.management import get_group_size, get_rank from mindspore.communication import management from mindspore._checkparam import check_int_positive from ..cell import Cell @@ -71,7 +71,7 @@ class _BatchNorm(Cell): self.group = check_int_positive(group) if self.group != 1: self.rank_id = get_rank() - self.rank_size = get_local_rank_size() + self.rank_size = get_group_size() self.device_list = [i for i in range(0, self.rank_size)] self.rank_list = self.list_group(self.device_list, self.group) self.rank_list_idx = len(self.rank_list) diff --git a/tests/ut/python/hccl_test/manage/api.py b/tests/ut/python/hccl_test/manage/api.py index 8dac167a3f..b684df5263 100644 --- a/tests/ut/python/hccl_test/manage/api.py +++ b/tests/ut/python/hccl_test/manage/api.py @@ -65,6 +65,14 @@ def get_rank_size(group=None): return int(group.split("-")[0]) raise ValueError +def get_group_size(group=None): + hccl = Hccl() + if group is None: + return hccl.rank_size + if isinstance(group, str): + return int(group.split("-")[0]) + raise ValueError + # pylint: disable=unused-argument def get_world_rank_from_group_rank(group, group_rank_id): return group_rank_id From d2b04664cad59608eb4754d2df93eaeaf84d7aca Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 16 Apr 2020 20:58:59 -0400 Subject: [PATCH 05/10] add global batch normalization --- mindspore/nn/layer/normalization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 2b55147cf1..c85b945a0d 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -110,9 +110,9 @@ class _BatchNorm(Cell): raise NotImplementedError def list_group(self, world_rank, group_size): - if group_size > get_local_rank_size(): + if group_size > get_group_size(): raise ValueError("group size can not be greater than local rank size, group size is {}, " - "local_rank_size is {}".format(group_size, get_local_rank_size())) + "local_rank_size is {}".format(group_size, get_group_size())) if len(world_rank) % group_size != 0: raise ValueError("please make your group size correct.") 
world_rank_list = zip(*(iter(world_rank),) *group_size) From d8bd5a09c4281c467e2576d140056c29b6826867 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 16 Apr 2020 21:13:45 -0400 Subject: [PATCH 06/10] add global batch normalization --- tests/ut/python/hccl_test/manage/api.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/ut/python/hccl_test/manage/api.py b/tests/ut/python/hccl_test/manage/api.py index b684df5263..04ce7da6d5 100644 --- a/tests/ut/python/hccl_test/manage/api.py +++ b/tests/ut/python/hccl_test/manage/api.py @@ -21,6 +21,7 @@ class Hccl(): _instance = None _rank_id = 0 _rank_size = 1 + _group_size = 4 def __init__(self): pass @@ -47,6 +48,10 @@ class Hccl(): def rank_size(self): return self._rank_size + @property + def group_size(self): + return self._group_size + @rank_size.setter def rank_size(self, size): self._rank_size = size @@ -68,7 +73,7 @@ def get_rank_size(group=None): def get_group_size(group=None): hccl = Hccl() if group is None: - return hccl.rank_size + return hccl.group_size if isinstance(group, str): return int(group.split("-")[0]) raise ValueError From f7872774f3ebb44e345af364bac9230adf6aac32 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Fri, 17 Apr 2020 08:33:40 -0400 Subject: [PATCH 07/10] add global batch normalization --- mindspore/nn/layer/normalization.py | 2 +- tests/ut/python/nn/test_batchnorm.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index c85b945a0d..04de71f71c 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -79,7 +79,7 @@ class _BatchNorm(Cell): if self.rank_id in self.rank_list[i] and self.group != 1: self.is_global = True management.create_group('group' + str(i), self.rank_list[i]) - self.all_reduce = _GlobalBNHelper('group' + str(i)) + self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1) self.shape = P.Shape() self.reduce_mean = P.ReduceMean() self.square = P.Square() diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py index 24f0de85f7..b6e27e6950 100644 --- a/tests/ut/python/nn/test_batchnorm.py +++ b/tests/ut/python/nn/test_batchnorm.py @@ -90,5 +90,4 @@ def test_global_bn(): device_num=size, parameter_broadcast=True) net = GlobalBNNet() input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32)) - net.set_train() - out = net(input_data) + _executor.compile(net,input_data) From c5120e770caa1e04aeefb287497a9674d2a05127 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Fri, 17 Apr 2020 08:37:08 -0400 Subject: [PATCH 08/10] add global batch normalization --- mindspore/nn/layer/normalization.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 04de71f71c..6456a3603d 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -27,14 +27,6 @@ from mindspore.communication import management from mindspore._checkparam import check_int_positive from ..cell import Cell -class _GlobalBNHelper(Cell): - def __init__(self, group): - super(_GlobalBNHelper, self).__init__() - self.group = group - self.reduce = P.AllReduce(P.ReduceOp.SUM, group=self.group).add_prim_attr('fusion', 1) - def construct(self, x): - x = self.reduce(x) - return x class _BatchNorm(Cell): """Batch Normalization base class.""" From 17e27824c54c163cb11bd1d4b8d3b257b149b123 Mon Sep 17 00:00:00 2001 
From: zhaojichen Date: Fri, 17 Apr 2020 21:47:39 -0400 Subject: [PATCH 09/10] add global batch normalization --- tests/ut/python/nn/test_batchnorm.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py index b6e27e6950..10b4cb00a1 100644 --- a/tests/ut/python/nn/test_batchnorm.py +++ b/tests/ut/python/nn/test_batchnorm.py @@ -73,21 +73,3 @@ def test_compile_groupnorm(): net = nn.GroupNorm(16, 64) input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32)) _executor.compile(net, input_data) - -class GlobalBNNet(nn.Cell): - def __init__(self): - super(GlobalBNNet, self).__init__() - self.bn = nn.GlobalBatchNorm(num_features = 2, group = 2) - def construct(self, x): - return self.bn(x) - -def test_global_bn(): - init("hccl") - size = 4 - context.set_context(mode=context.GRAPH_MODE) - context.reset_auto_parallel_context() - context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, - device_num=size, parameter_broadcast=True) - net = GlobalBNNet() - input_data = Tensor(np.array([[2.4, 2.1], [3.2, 5.4]], dtype=np.float32)) - _executor.compile(net,input_data) From 97e250d4f1898000cbc44a2620cd36f6e52abd2f Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Fri, 17 Apr 2020 21:51:45 -0400 Subject: [PATCH 10/10] add global batch normalization --- tests/ut/python/hccl_test/manage/api.py | 13 ------------- tests/ut/python/nn/test_batchnorm.py | 3 --- 2 files changed, 16 deletions(-) diff --git a/tests/ut/python/hccl_test/manage/api.py b/tests/ut/python/hccl_test/manage/api.py index 04ce7da6d5..8dac167a3f 100644 --- a/tests/ut/python/hccl_test/manage/api.py +++ b/tests/ut/python/hccl_test/manage/api.py @@ -21,7 +21,6 @@ class Hccl(): _instance = None _rank_id = 0 _rank_size = 1 - _group_size = 4 def __init__(self): pass @@ -48,10 +47,6 @@ class Hccl(): def rank_size(self): return self._rank_size - @property - def group_size(self): - return self._group_size - @rank_size.setter def rank_size(self, size): self._rank_size = size @@ -70,14 +65,6 @@ def get_rank_size(group=None): return int(group.split("-")[0]) raise ValueError -def get_group_size(group=None): - hccl = Hccl() - if group is None: - return hccl.group_size - if isinstance(group, str): - return int(group.split("-")[0]) - raise ValueError - # pylint: disable=unused-argument def get_world_rank_from_group_rank(group, group_rank_id): return group_rank_id diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py index 10b4cb00a1..e73b7ebbf0 100644 --- a/tests/ut/python/nn/test_batchnorm.py +++ b/tests/ut/python/nn/test_batchnorm.py @@ -19,9 +19,6 @@ import pytest import mindspore.nn as nn from mindspore.common.api import _executor from mindspore import Tensor, Parameter -from mindspore.communication.management import init -from mindspore import context -from mindspore import ParallelMode def test_bn_pars_valid1():
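
Two pieces of the series are worth spelling out. First, the synchronized statistics in construct(): each device reduces its local mean E[x] and mean of squares E[x^2], an AllReduce sums them across the group, and the group variance falls out of the identity Var[x] = E[x^2] - E[x]^2. A minimal NumPy sketch of that arithmetic follows; the per-device batches and the two-device group are hypothetical stand-ins, and the closing asserts assume equal per-device batch sizes:

    import numpy as np

    # Hypothetical stand-ins for the tensors held by two ranks in one group.
    device_batches = [np.array([2.4, 2.1], dtype=np.float32),
                      np.array([3.2, 5.4], dtype=np.float32)]
    group = len(device_batches)

    # Each rank computes its local statistics (the ReduceMean and Square ops).
    local_means = [b.mean() for b in device_batches]
    local_mean_squares = [(b ** 2).mean() for b in device_batches]

    # AllReduce(SUM) across the group, then divide by the group size.
    global_mean = sum(local_means) / group
    global_mean_square = sum(local_mean_squares) / group

    # Var[x] = E[x^2] - E[x]^2, matching global_var in the patched construct().
    global_var = global_mean_square - global_mean ** 2

    # With equal per-device batch sizes this agrees with the statistics of the
    # concatenated batch, which is what makes the normalization "global".
    all_data = np.concatenate(device_batches)
    assert np.isclose(global_mean, all_data.mean())
    assert np.isclose(global_var, all_data.var())

The moving statistics are then blended from global_mean and global_var exactly as the per-device branch blends batch_mean and batch_var. Second, list_group() carves the visible ranks into consecutive, non-overlapping communication groups with the zip-over-a-shared-iterator idiom; a short sketch, assuming four ranks and a group size of two:

    world_rank = [0, 1, 2, 3]
    group_size = 2
    # Every zip argument is the *same* iterator, so each tuple consumes
    # group_size consecutive ranks; any remainder would be silently dropped,
    # which is why the helper validates len(world_rank) % group_size first.
    world_rank_list = zip(*(iter(world_rank),) * group_size)
    group_list = [list(i) for i in world_rank_list]
    assert group_list == [[0, 1], [2, 3]]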