From ebe6efff719bec947f2f73d42cdd542ff67f932e Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Thu, 16 Apr 2020 01:42:32 -0400
Subject: [PATCH 1/5] Add Group Normalization

---
 mindspore/nn/layer/__init__.py       |  4 +-
 mindspore/nn/layer/normalization.py  | 73 +++++++++++++++++++++++++++-
 tests/ut/python/nn/test_batchnorm.py | 12 +++++
 3 files changed, 85 insertions(+), 4 deletions(-)

diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py
index 3d729edcd3..cf601f03ff 100644
--- a/mindspore/nn/layer/__init__.py
+++ b/mindspore/nn/layer/__init__.py
@@ -18,7 +18,7 @@ Layer.
 The high-level components(Cells) used to construct the neural network.
 """
 from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish
-from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
+from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm
 from .container import SequentialCell, CellList
 from .conv import Conv2d, Conv2dTranspose
 from .lstm import LSTM
@@ -29,7 +29,7 @@ from .image import ImageGradients, SSIM
 __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
            'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU',
-           'BatchNorm1d', 'BatchNorm2d', 'LayerNorm',
+           'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm',
            'SequentialCell', 'CellList',
            'Conv2d', 'Conv2dTranspose',
            'LSTM',
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 2df064353f..cac73d239e 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -18,8 +18,9 @@ from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
-import mindspore.common.dtype as DT
+import mindspore.common.dtype as mstype
 import mindspore.context as context
+from mindspore._checkparam import check_int_positive, check_bool,check_typename
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
@@ -58,7 +59,7 @@ class _BatchNorm(Cell):
 
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
-            self.momentum = Tensor(1.0 - momentum, DT.float32)
+            self.momentum = Tensor(1.0 - momentum, mstype.float32)
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps)
         else:
@@ -289,3 +290,71 @@ class LayerNorm(Cell):
         s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
             self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
         return s
+
+class GroupNorm(Cell):
+    r"""
+    Group Normalization over a mini-batch of inputs.
+
+    Group normalization is widely used in computer vision tasks where only small batch sizes are feasible. It applies
+    normalization over a mini-batch of inputs for each single training case as described
+    in the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`_. Group normalization
+    divides the channels into groups and computes within each group the mean and variance for normalization,
+    and its accuracy is stable over a wide range of batch sizes. It can be described using the following formula.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Args:
+        num_groups (int): The number of groups to split the channels into along the channel dimension.
+        num_channels (int): The number of channels in the input.
+        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
+        affine (bool): A bool value. When set to True, this layer has learnable gamma and beta parameters. Default: True.
+
+    Inputs:
+        - **input_x** (Tensor) - The input feature with shape [N, C, H, W].
+
+    Outputs:
+        Tensor, the normalized, scaled, and shifted tensor, which has the same shape and data type as `input_x`.
+
+    Examples:
+        >>> group_norm_op = nn.GroupNorm(16, 64)
+        >>> x = Tensor(np.ones([1, 64, 256, 256], np.float32))
+        >>> group_norm_op(x)
+    """
+    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True):
+        super(GroupNorm, self).__init__()
+        self.num_groups = check_int_positive(num_groups)
+        self.num_channels = check_int_positive(num_channels)
+        if num_channels % num_groups != 0:
+            raise ValueError("num_channels must be divisible by num_groups")
+        self.eps = Tensor(check_typename('eps', eps, (float,)),mstype.float32)
+        self.affine = check_bool(affine)
+
+        gamma = initializer('ones', [num_channels, 1, 1], mstype.float32)
+        beta = initializer('zeros', [num_channels, 1, 1], mstype.float32)
+        if self.affine:
+            self.gamma = Parameter(gamma, name='gamma')
+            self.beta = Parameter(beta, name='beta')
+        else:
+            self.gamma = gamma
+            self.beta = beta
+        self.shape = F.shape
+        self.reshape = F.reshape
+        self.reduce_mean = P.ReduceMean(keep_dims=True)
+        self.square = F.square
+        self.reduce_sum = P.ReduceSum(keep_dims=True)
+        self.sqrt = P.Sqrt()
+
+    def construct(self, x):
+        batch,channel,height,width = self.shape(x)
+        x = self.reshape(x,(batch, self.num_groups,channel*height*width/self.num_groups))
+        mean = self.reduce_mean(x, 2)
+        var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
+        std = self.sqrt(var + self.eps)
+        x = (x - mean) / std
+        x = self.reshape(x, (batch, channel, height, width))
+        output = x * self.gamma + self.beta
+        return output
+
+    def extend_repr(self):
+        return 'num_groups={}, num_channels={}'.format(self.num_groups,self.num_channels)
\ No newline at end of file
diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py
index eaafdd81b4..efccfa4b33 100644
--- a/tests/ut/python/nn/test_batchnorm.py
+++ b/tests/ut/python/nn/test_batchnorm.py
@@ -56,3 +56,15 @@ def test_compile():
     net = Net()
     input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
     _executor.compile(net, input_data)
+
+class GroupNet(nn.Cell):
+    def __init__(self):
+        super(GroupNet, self).__init__()
+        self.group_bn = nn.GroupNorm(16, 64)
+    def construct(self, x):
+        return self.group_bn(x)
+
+def test_compile_groupnorm():
+    net = nn.GroupNorm(16, 64)
+    input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))
+    _executor.compile(net, input_data)
\ No newline at end of file

From 898acc3201f65b3943afc6b7a187857ad4471a06 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Thu, 16 Apr 2020 01:51:36 -0400
Subject: [PATCH 2/5] Add Group Normalization

---
 mindspore/nn/layer/normalization.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index cac73d239e..58f926cdcf 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -20,7 +20,7 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 import mindspore.context as context
-from mindspore._checkparam import check_int_positive, check_bool,check_typename
+from mindspore._checkparam import check_int_positive, check_bool, check_typename
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
 
@@ -327,7 +327,7 @@ class GroupNorm(Cell):
         self.num_channels = check_int_positive(num_channels)
         if num_channels % num_groups != 0:
             raise ValueError("num_channels must be divisible by num_groups")
-        self.eps = Tensor(check_typename('eps', eps, (float,)),mstype.float32)
+        self.eps = Tensor(check_typename('eps', eps, (float,)), mstype.float32)
         self.affine = check_bool(affine)
 
         gamma = initializer('ones', [num_channels, 1, 1], mstype.float32)
@@ -346,8 +346,8 @@ class GroupNorm(Cell):
         self.sqrt = P.Sqrt()
 
     def construct(self, x):
-        batch,channel,height,width = self.shape(x)
-        x = self.reshape(x,(batch, self.num_groups,channel*height*width/self.num_groups))
+        batch, channel, height,width = self.shape(x)
+        x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups))
         mean = self.reduce_mean(x, 2)
         var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
         std = self.sqrt(var + self.eps)
@@ -357,4 +357,6 @@ class GroupNorm(Cell):
         return output
 
     def extend_repr(self):
-        return 'num_groups={}, num_channels={}'.format(self.num_groups,self.num_channels)
\ No newline at end of file
+        """Display instance object as string."""
+        s = 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
+        return s
\ No newline at end of file

From 0b7de6968fe1dcbc9c78f292bdc2b673502b5f2c Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Thu, 16 Apr 2020 01:56:12 -0400
Subject: [PATCH 3/5] Add Group Normalization

---
 mindspore/nn/layer/normalization.py  | 4 ++--
 tests/ut/python/nn/test_batchnorm.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 58f926cdcf..b286eaae1b 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -346,7 +346,7 @@ class GroupNorm(Cell):
         self.sqrt = P.Sqrt()
 
     def construct(self, x):
-        batch, channel, height,width = self.shape(x)
+        batch, channel, height, width = self.shape(x)
         x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups))
         mean = self.reduce_mean(x, 2)
         var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
@@ -359,4 +359,4 @@ class GroupNorm(Cell):
     def extend_repr(self):
         """Display instance object as string."""
         s = 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
-        return s
\ No newline at end of file
+        return s
diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py
index efccfa4b33..4bd8c996d6 100644
--- a/tests/ut/python/nn/test_batchnorm.py
+++ b/tests/ut/python/nn/test_batchnorm.py
@@ -67,4 +67,4 @@ class GroupNet(nn.Cell):
 def test_compile_groupnorm():
     net = nn.GroupNorm(16, 64)
     input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))
-    _executor.compile(net, input_data)
\ No newline at end of file
+    _executor.compile(net, input_data)

From 04c522d0c6edfe4074c9bdecf16146fbe617644d Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Thu, 16 Apr 2020 01:59:32 -0400
Subject: [PATCH 4/5] Add Group Normalization

---
 tests/ut/python/nn/test_batchnorm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/ut/python/nn/test_batchnorm.py b/tests/ut/python/nn/test_batchnorm.py
index 4bd8c996d6..e73b7ebbf0 100644
--- a/tests/ut/python/nn/test_batchnorm.py
+++ b/tests/ut/python/nn/test_batchnorm.py
@@ -57,6 +57,7 @@ def test_compile():
     input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
     _executor.compile(net, input_data)
 
+
 class GroupNet(nn.Cell):
     def __init__(self):
         super(GroupNet, self).__init__()
         self.group_bn = nn.GroupNorm(16, 64)
     def construct(self, x):
         return self.group_bn(x)
 
+
 def test_compile_groupnorm():
     net = nn.GroupNorm(16, 64)
     input_data = Tensor(np.random.rand(1,64,256,256).astype(np.float32))

From 77fd2e841eab957a08b766a210e2db45942263a4 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Thu, 16 Apr 2020 03:54:32 -0400
Subject: [PATCH 5/5] Add Group Normalization

---
 mindspore/nn/layer/normalization.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index b286eaae1b..4aafaf031e 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -20,7 +20,7 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 import mindspore.context as context
-from mindspore._checkparam import check_int_positive, check_bool, check_typename 
+from mindspore._checkparam import check_int_positive, check_bool, check_typename
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
@@ -293,7 +293,7 @@ class LayerNorm(Cell):
 
 class GroupNorm(Cell):
     r"""
-    Group Normalization over a mini-batch of inputs. 
+    Group Normalization over a mini-batch of inputs.
 
     Group normalization is widely used in computer vision tasks where only small batch sizes are feasible. It applies
     normalization over a mini-batch of inputs for each single training case as described