From d1002883592360410c0efcbd05252ad5983b2c36 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Sun, 26 Apr 2020 02:24:00 -0400
Subject: [PATCH 1/4] fix groupnorm bug and change globalbn parameter name

---
 mindspore/nn/layer/normalization.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 3ef2381ba1..09a0b4bb27 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -41,7 +41,7 @@ class _BatchNorm(Cell):
                  moving_mean_init='zeros',
                  moving_var_init='ones',
                  use_batch_statistics=True,
-                 group=1):
+                 device_num_each_group=1):
         super(_BatchNorm, self).__init__()
         if num_features < 1:
             raise ValueError("num_features must be at least 1")
@@ -60,7 +60,7 @@ class _BatchNorm(Cell):
             gamma_init, num_features), name="gamma", requires_grad=affine)
         self.beta = Parameter(initializer(
             beta_init, num_features), name="beta", requires_grad=affine)
-        self.group = check_int_positive(group)
+        self.group = check_int_positive(device_num_each_group)
         if self.group != 1:
             self.rank_id = get_rank()
             self.rank_size = get_group_size()
@@ -324,7 +324,7 @@ class GlobalBatchNorm(_BatchNorm):
 
     Args:
         num_features (int): `C` from an expected input of size (N, C, H, W).
-        group (int): The number of device in each group.
+        device_num_each_group (int): The number of device in each group.
         eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
         momentum (float): A floating hyperparameter of the momentum for the
             running_mean and running_var computation. Default: 0.9.
@@ -364,7 +364,7 @@ class GlobalBatchNorm(_BatchNorm):
                  moving_mean_init='zeros',
                  moving_var_init='ones',
                  use_batch_statistics=True,
-                 group=1):
+                 device_num_each_group=1):
         super(GlobalBatchNorm, self).__init__(num_features,
                                               eps,
                                               momentum,
@@ -374,8 +374,8 @@ class GlobalBatchNorm(_BatchNorm):
                                               moving_mean_init,
                                               moving_var_init,
                                               use_batch_statistics,
-                                              group)
-        self.group = check_int_positive(group)
+                                              device_num_each_group)
+        self.group = check_int_positive(device_num_each_group)
         if self.group <= 1:
             raise ValueError("the number of group must be greater than 1.")
     def _check_data_dim(self, x):
@@ -482,17 +482,17 @@ class GroupNorm(Cell):
         >>> x = Tensor(np.ones([1, 64, 256, 256], np.float32))
         >>> goup_norm_op(x)
     """
-    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True):
+    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
        super(GroupNorm, self).__init__()
         self.num_groups = check_int_positive(num_groups)
         self.num_channels = check_int_positive(num_channels)
         if num_channels % num_groups != 0:
             raise ValueError("num_channels should be divided by num_groups")
-        self.eps = Tensor(check_typename('eps', eps, (float,)), mstype.float32)
+        self.eps = check_typename('eps', eps, (float,))
         self.affine = check_bool(affine)
 
-        gamma = initializer('ones', [num_channels, 1, 1], mstype.float32)
-        beta = initializer('zeros', [num_channels, 1, 1], mstype.float32)
+        gamma = initializer(gamma_init, [num_channels, 1, 1])
+        beta = initializer(beta_init, [num_channels, 1, 1])
         if self.affine:
             self.gamma = Parameter(gamma, name='gamma')
             self.beta = Parameter(beta, name='beta')

From f53bd08c80810659caa094baf3e137102054b1f7 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Sun, 26 Apr 2020 08:08:56 -0400
Subject: [PATCH 2/4] fix groupnorm bug and change globalbn parameter name

---
 mindspore/train/model.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 698105889a..3fccb7aa2b 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -83,7 +83,7 @@ class Model:
        >>>             return out
        >>>
        >>> net = Net()
-        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
        >>> dataset = get_dataset()
@@ -395,7 +395,7 @@ class Model:
         Examples:
             >>> dataset = get_dataset()
             >>> net = Net()
-            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
             >>> loss_scale_manager = FixedLossScaleManager()
             >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
             >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
@@ -518,7 +518,7 @@ class Model:
         Examples:
             >>> dataset = get_dataset()
             >>> net = Net()
-            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
             >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
             >>> model.eval(dataset)
         """

From f4de3741243c0cb4b519ff4faed6efd81f100b23 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Sun, 26 Apr 2020 22:11:32 -0400
Subject: [PATCH 3/4] fix groupnorm bug and change globalbn parameter name

---
 mindspore/nn/layer/normalization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 09a0b4bb27..1066d7f94c 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -350,7 +350,7 @@ class GlobalBatchNorm(_BatchNorm):
         Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
 
     Examples:
-        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, group=4)
+        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=4)
         >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
         >>> global_bn_op(input)
     """

From 56323f54a3b6857f17f7439842ed73147c7a9b09 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Mon, 27 Apr 2020 02:55:11 -0400
Subject: [PATCH 4/4] fix groupnorm bug and change globalbn parameter name

---
 mindspore/nn/layer/normalization.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 1066d7f94c..ddd1bab1bf 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -17,6 +17,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
+from mindspore.ops.primitive import constexpr
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 import mindspore.context as context
@@ -166,6 +167,10 @@ class _BatchNorm(Cell):
         return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
             self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
 
+@constexpr
+def _channel_check(channel, num_channel):
+    if channel != num_channel:
+        raise ValueError("the input channel is not equal with num_channel")
 
 class BatchNorm1d(_BatchNorm):
     r"""
@@ -508,6 +513,7 @@ class GroupNorm(Cell):
 
     def construct(self, x):
         batch, channel, height, width = self.shape(x)
+        _channel_check(channel, self.num_channels)
         x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups))
         mean = self.reduce_mean(x, 2)
         var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
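
Note (not part of the patches): a minimal usage sketch of the two interfaces patch 1 changes, based only on the signatures and docstring examples above. GroupNorm gains gamma_init/beta_init keyword arguments and keeps eps as a plain float; GlobalBatchNorm takes device_num_each_group where it previously took group.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    # GroupNorm after patch 1: gamma_init/beta_init are new keyword arguments,
    # and eps is stored as a plain float instead of a Tensor.
    group_norm = nn.GroupNorm(num_groups=2, num_channels=2, eps=1e-5,
                              affine=True, gamma_init='ones', beta_init='zeros')
    x = Tensor(np.ones([1, 2, 4, 4], np.float32))
    out = group_norm(x)

    # GlobalBatchNorm after patch 1: the old `group` argument is now
    # `device_num_each_group` and must be greater than 1, so the call below
    # only makes sense in a multi-device (distributed) run.
    # global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=4)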
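Patch 2 only touches docstrings, but the pattern it documents is how Model is meant to be wired up after the change: the loss cell is constructed with is_grad=False, sparse=True. A sketch mirroring the updated docstring, where Net() and get_dataset() stand in for the user-defined network and dataset from that docstring:

    import mindspore.nn as nn
    from mindspore.nn.optim import Momentum
    from mindspore.train.model import Model

    # Net() and get_dataset() are the placeholders used in the Model docstring.
    net = Net()
    # Loss constructed exactly as the updated docstring shows:
    # is_grad=False, sparse=True (integer class labels).
    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    dataset = get_dataset()
    model.train(2, dataset)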
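Patch 4 adds a @constexpr channel check to GroupNorm.construct, so a mismatch between the input's channel dimension and num_channels is reported up front as a ValueError instead of surfacing later as a shape error. A small sketch of what the check catches, assuming the patched nn.GroupNorm:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    gn = nn.GroupNorm(num_groups=2, num_channels=4)

    good_x = Tensor(np.ones([1, 4, 8, 8], np.float32))
    out = gn(good_x)  # channel (4) matches num_channels (4)

    bad_x = Tensor(np.ones([1, 8, 8, 8], np.float32))  # 8 channels, but num_channels=4
    # With patch 4 applied, the next call raises:
    # ValueError: the input channel is not equal with num_channel
    # gn(bad_x)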