Browse Source

!1136 fix BatchNorm train & eval loss problem

Merge pull request !1136 from JichenZhao/bn_train_eval_loss_issue
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
93a5201210
1 changed file with 21 additions and 11 deletions
  1. +21
    -11
      mindspore/nn/layer/normalization.py

+ 21
- 11
mindspore/nn/layer/normalization.py View File

@@ -43,7 +43,7 @@ class _BatchNorm(Cell):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True,
use_batch_statistics=None,
device_num_each_group=1):
super(_BatchNorm, self).__init__()
if num_features < 1:
@@ -147,7 +147,11 @@ class _BatchNorm(Cell):
return y

def construct(self, x):
if self.training and self.use_batch_statistics:
if self.use_batch_statistics is None:
flag = self.training
else:
flag = self.use_batch_statistics
if flag:
if self.is_ge_backend and self.is_global:
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
@@ -236,8 +240,10 @@ class BatchNorm1d(_BatchNorm):
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
the mean value and variance value of specified value. Default: True.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
use the mean value and variance value of specified value. If None, training process will use the mean and
variance of current batch data and track the running mean and variance, eval process will use the running
mean and variance. Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -259,7 +265,7 @@ class BatchNorm1d(_BatchNorm):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True):
use_batch_statistics=None):
super(BatchNorm1d, self).__init__(num_features,
eps,
momentum,
@@ -307,8 +313,10 @@ class BatchNorm2d(_BatchNorm):
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
the mean value and variance value of specified value. Default: True.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
use the mean value and variance value of specified value. If None, training process will use the mean and
variance of current batch data and track the running mean and variance, eval process will use the running
mean and variance. Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -330,7 +338,7 @@ class BatchNorm2d(_BatchNorm):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True):
use_batch_statistics=None):
super(BatchNorm2d, self).__init__(num_features,
eps,
momentum,
@@ -379,8 +387,10 @@ class GlobalBatchNorm(_BatchNorm):
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
the mean value and variance value of specified value. Default: True.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
use the mean value and variance value of specified value. If None, training process will use the mean and
variance of current batch data and track the running mean and variance, eval process will use the running
mean and variance. Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -402,7 +412,7 @@ class GlobalBatchNorm(_BatchNorm):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True,
use_batch_statistics=None,
device_num_each_group=1):
super(GlobalBatchNorm, self).__init__(num_features,
eps,


Loading…
Cancel
Save