Browse Source

Optimization for BatchNorm

tags/v0.7.0-beta
wuyongkang 5 years ago
parent
commit
a8fd71c785
1 changed files with 41 additions and 47 deletions
  1. +41
    -47
      mindspore/nn/layer/normalization.py

+ 41
- 47
mindspore/nn/layer/normalization.py View File

@@ -101,6 +101,9 @@ class _BatchNorm(Cell):
epsilon=self.eps,
momentum=self.momentum)
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps)
self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend))
self.enable_default_train = self.is_graph_mode and not self.is_global and \
(self.is_ge_backend or self.is_ascend)

data_parallel_strategy = ((1,), (1,))
data_parallel_strategy_one = ((1,), ())
@@ -147,51 +150,43 @@ class _BatchNorm(Cell):
return y

def construct(self, x):
if self.input_dims == '2d':
_shape_check(self.shape(x))
if self.input_dims == '1d':
_shape_check_2d(self.shape(x))
if self.input_dims == 'both':
_shape_check_2d_or_4d(self.shape(x))
_shape_check_bn(self.shape(x), self.input_dims)
if self.use_batch_statistics is None:
flag = self.training
else:
flag = self.use_batch_statistics

if flag:
if self.is_ge_backend and self.is_global:
if self.enable_global_sync:
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
if self.is_global:
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
else:
y, batch_mean, batch_var, _, _ = \
self.bn_train(x,
self.gamma,
self.beta,
None,
None)

mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
else:
y = self.bn_train(x,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]
else:
y = self.bn_infer(x,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]
return y
return self._global_sync(x, axes, re_shape)

if self.enable_default_train:
y, batch_mean, batch_var, _, _ = self.bn_train(x,
self.gamma,
self.beta,
None,
None)

mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
return y

return self.bn_train(x,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]

return self.bn_infer(x,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]

def extend_repr(self):
return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
@@ -204,12 +199,6 @@ def _channel_check(channel, num_channel):
raise ValueError("the input channel is not equal with num_channel")


@constexpr
def _shape_check_2d(input_shape):
if len(input_shape) != 2:
raise ValueError("The input must has 2 dims.")


@constexpr
def _shape_check(in_shape):
if len(in_shape) != 4:
@@ -217,8 +206,13 @@ def _shape_check(in_shape):


@constexpr
def _shape_check_2d_or_4d(in_shape):
if len(in_shape) != 2 and len(in_shape) != 4:
def _shape_check_bn(in_shape, in_dims):
dim = len(in_shape)
if in_dims == '1d' and dim != 2:
raise ValueError("The input must has 2 dims.")
if in_dims == '2d' and dim != 4:
raise ValueError("The input must has 4 dims.")
if in_dims == 'both' and dim != 2 and dim != 4:
raise ValueError("The input must has 2 dims or 4 dims.")




Loading…
Cancel
Save