Browse Source

Fix BatchNorm (bn) train & eval loss problem

tags/v0.5.0-beta
zhaojichen 6 years ago
parent
commit
59993c4843
1 changed file with 21 additions and 11 deletions
  1. +21
    -11
      mindspore/nn/layer/normalization.py

+ 21
- 11
mindspore/nn/layer/normalization.py View File

@@ -43,7 +43,7 @@ class _BatchNorm(Cell):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True,
use_batch_statistics=None,
device_num_each_group=1):
super(_BatchNorm, self).__init__()
if num_features < 1:
@@ -147,7 +147,11 @@ class _BatchNorm(Cell):
return y

def construct(self, x):
if self.training and self.use_batch_statistics:
if self.use_batch_statistics is None:
flag = self.training
else:
flag = self.use_batch_statistics
if flag:
if self.is_ge_backend and self.is_global:
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
@@ -236,8 +240,10 @@ class BatchNorm1d(_BatchNorm):
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
the mean value and variance value of specified value. Default: True.
use_batch_statistics (Union[bool, None]): If True, use the mean and variance of the current batch data. If False,
use the specified mean and variance values. If None, the training process uses the mean and variance of the
current batch data and tracks the running mean and variance, while the eval process uses the running mean
and variance. Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -259,7 +265,7 @@ class BatchNorm1d(_BatchNorm):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True):
use_batch_statistics=None):
super(BatchNorm1d, self).__init__(num_features,
eps,
momentum,
@@ -307,8 +313,10 @@ class BatchNorm2d(_BatchNorm):
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
the mean value and variance value of specified value. Default: True.
use_batch_statistics (Union[bool, None]): If True, use the mean and variance of the current batch data. If False,
use the specified mean and variance values. If None, the training process uses the mean and variance of the
current batch data and tracks the running mean and variance, while the eval process uses the running mean
and variance. Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -330,7 +338,7 @@ class BatchNorm2d(_BatchNorm):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True):
use_batch_statistics=None):
super(BatchNorm2d, self).__init__(num_features,
eps,
momentum,
@@ -379,8 +387,10 @@ class GlobalBatchNorm(_BatchNorm):
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
the mean value and variance value of specified value. Default: True.
use_batch_statistics (Union[bool, None]): If True, use the mean and variance of the current batch data. If False,
use the specified mean and variance values. If None, the training process uses the mean and variance of the
current batch data and tracks the running mean and variance, while the eval process uses the running mean
and variance. Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -402,7 +412,7 @@ class GlobalBatchNorm(_BatchNorm):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True,
use_batch_statistics=None,
device_num_each_group=1):
super(GlobalBatchNorm, self).__init__(num_features,
eps,


Loading…
Cancel
Save