Browse Source

amend bn

pull/15581/head
jiangzhenguang 5 years ago
parent
commit
3a85c18ff0
1 changed files with 28 additions and 12 deletions
  1. +28
    -12
      mindspore/nn/layer/normalization.py

+ 28
- 12
mindspore/nn/layer/normalization.py View File

@@ -39,6 +39,7 @@ __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm'


SYNC_BN_GROUP_NAME = "" SYNC_BN_GROUP_NAME = ""



class _BatchNorm(Cell): class _BatchNorm(Cell):
"""Batch Normalization base class.""" """Batch Normalization base class."""


@@ -69,6 +70,8 @@ class _BatchNorm(Cell):
if context.get_context("device_target") != "GPU" and self.format == "NHWC": if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.") raise ValueError("NHWC format only support in GPU target.")
self.use_batch_statistics = use_batch_statistics self.use_batch_statistics = use_batch_statistics
if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool):
raise ValueError("use_batch_statistics should be a boolean value or None.")
self.num_features = num_features self.num_features = num_features
self.eps = eps self.eps = eps
self.moving_mean = Parameter(initializer( self.moving_mean = Parameter(initializer(
@@ -95,7 +98,7 @@ class _BatchNorm(Cell):
if self.rank_id in self.rank_list[i]: if self.rank_id in self.rank_list[i]:
self.is_global = True self.is_global = True
if SYNC_BN_GROUP_NAME == "": if SYNC_BN_GROUP_NAME == "":
SYNC_BN_GROUP_NAME = "sync_bn_group"+ str(i)
SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
# for SyncBatchNorm # for SyncBatchNorm
if self.process_groups != 0: if self.process_groups != 0:
@@ -105,7 +108,7 @@ class _BatchNorm(Cell):
validator.check_isinstance("process_groups", self.process_groups, list) validator.check_isinstance("process_groups", self.process_groups, list)
self._check_rank_ids(self.process_groups, self.rank_size) self._check_rank_ids(self.process_groups, self.rank_size)
for i in range(len(self.process_groups)): for i in range(len(self.process_groups)):
validator.check_isinstance("process_groups[" + str(i) +"]", self.process_groups[i], list)
validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list)
self.group_device_num = len(self.process_groups[i]) self.group_device_num = len(self.process_groups[i])
if self.rank_id in self.process_groups[i] and self.group_device_num > 1: if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
self.is_global = True self.is_global = True
@@ -180,11 +183,20 @@ class _BatchNorm(Cell):
def construct(self, x): def construct(self, x):
_shape_check_bn(self.shape(x), self.input_dims) _shape_check_bn(self.shape(x), self.input_dims)
if self.use_batch_statistics is None: if self.use_batch_statistics is None:
flag = self.training
else:
flag = self.use_batch_statistics

if flag:
if self.training:
return self.bn_train(x,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]
if not self.training:
return self.bn_infer(x,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]

if self.use_batch_statistics is True:
return self.bn_train(x, return self.bn_train(x,
self.gamma, self.gamma,
self.beta, self.beta,
@@ -365,10 +377,14 @@ class BatchNorm2d(_BatchNorm):
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
use the mean value and variance value of specified value. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use
the running mean and variance. Default: None.
use_batch_statistics (bool):
If true, use the mean value and variance value of the current batch data and track the running mean
and running variance.
If false, use the specified mean value and variance value, and do not track statistical values.
If None, use_batch_statistics is assigned automatically according to the training and eval mode.
During training, the batchnorm2d process will be the same as with use_batch_statistics=True.
Conversely, in eval, the batchnorm2d process will be the same as with use_batch_statistics=False.
Default: None.
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
Default: 'NCHW'. Default: 'NCHW'.


@@ -527,7 +543,7 @@ class BatchNorm3d(Cell):
def construct(self, input_x):
    """Apply 3-D batch normalization by delegating to a 2-D batch-norm cell.

    The 5-D input — presumably laid out as (N, C, D, H, W); confirm against
    the BatchNorm3d class definition — is reshaped so that dimensions 2 and 3
    are merged, normalized with the wrapped 2-D batch-norm cell, and then
    reshaped back to its original layout.
    """
    original_shape = F.shape(input_x)
    _check_3d_shape(original_shape)
    # Collapse dims 2 and 3 into one so the 4-D batch-norm cell can consume it.
    merged = self.reshape(
        input_x,
        (original_shape[0], original_shape[1],
         original_shape[2] * original_shape[3], original_shape[4]))
    normalized = self.bn2d(merged)
    # Restore the original 5-D shape before returning.
    return self.reshape(normalized, original_shape)


Loading…
Cancel
Save