Browse Source

!15581 fix bn2d

From: @jiangzg001
Reviewed-by: @tom__chen,@liangchenghui
Signed-off-by: @liangchenghui
pull/15581/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
058bc0c4ea
1 changed file with 28 additions and 12 deletions
  1. +28
    -12
      mindspore/nn/layer/normalization.py

+ 28
- 12
mindspore/nn/layer/normalization.py View File

@@ -39,6 +39,7 @@ __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm'

SYNC_BN_GROUP_NAME = ""


class _BatchNorm(Cell):
"""Batch Normalization base class."""

@@ -69,6 +70,8 @@ class _BatchNorm(Cell):
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.use_batch_statistics = use_batch_statistics
if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool):
raise ValueError("use_batch_statistics should be a boolean value or None.")
self.num_features = num_features
self.eps = eps
self.moving_mean = Parameter(initializer(
@@ -95,7 +98,7 @@ class _BatchNorm(Cell):
if self.rank_id in self.rank_list[i]:
self.is_global = True
if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "sync_bn_group"+ str(i)
SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
# for SyncBatchNorm
if self.process_groups != 0:
@@ -105,7 +108,7 @@ class _BatchNorm(Cell):
validator.check_isinstance("process_groups", self.process_groups, list)
self._check_rank_ids(self.process_groups, self.rank_size)
for i in range(len(self.process_groups)):
validator.check_isinstance("process_groups[" + str(i) +"]", self.process_groups[i], list)
validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list)
self.group_device_num = len(self.process_groups[i])
if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
self.is_global = True
@@ -183,11 +186,20 @@ class _BatchNorm(Cell):
def construct(self, x):
_shape_check_bn(self.shape(x), self.input_dims)
if self.use_batch_statistics is None:
flag = self.training
else:
flag = self.use_batch_statistics

if flag:
if self.training:
return self.bn_train(x,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]
if not self.training:
return self.bn_infer(x,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]

if self.use_batch_statistics is True:
return self.bn_train(x,
self.gamma,
self.beta,
@@ -368,10 +380,14 @@ class BatchNorm2d(_BatchNorm):
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
use the mean value and variance value of specified value. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use
the running mean and variance. Default: None.
use_batch_statistics (bool):
    If true, use the mean value and variance value of the current batch data and track the running mean
    and running variance.
    If false, use the mean value and variance value of specified values, and do not track statistical values.
    If None, use_batch_statistics is automatically assigned according to the training or evaluation mode.
    During training, the batchnorm2d process behaves the same as with use_batch_statistics=True.
    Conversely, in evaluation mode, it behaves the same as with use_batch_statistics=False.
    Default: None.
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
Default: 'NCHW'.

@@ -530,7 +546,7 @@ class BatchNorm3d(Cell):
def construct(self, input_x):
x_shape = F.shape(input_x)
_check_3d_shape(x_shape)
input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2]*x_shape[3], x_shape[4]))
input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
bn2d_out = self.bn2d(input_x)
bn3d_out = self.reshape(bn2d_out, x_shape)
return bn3d_out


Loading…
Cancel
Save