
fix globalbatchnorm bug

tags/v0.3.0-alpha
zhaojichen 5 years ago
commit 8ca1f87a49
1 changed file with 27 additions and 36 deletions:

mindspore/nn/layer/normalization.py  +27 −36

@@ -116,53 +116,44 @@ class _BatchNorm(Cell):
         group_list = [list(i) for i in world_rank_list]
         return group_list
 
-    def _global_sync(self, x):
-        """calculate global batch normalization output"""
+    def _shape_infer(self, x):
+        """global batch normalization shape and axes infer"""
         if len(self.shape(x)) == 4:
-            axes = (0, 2, 3)
+            axes = (0,2,3)
             re_shape = (1, self.num_features, 1, 1)
-            x_mean = self.reduce_mean(x, axes)
-            x_mean_square = self.reduce_mean(self.square(x), axes)
-            global_batch_mean = self.all_reduce(x_mean) / self.group
-            global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
-            global_mean = global_batch_mean
-            global_var = global_batch_mean_square - self.square(global_mean)
-            var_sqrt = self.sqrt(global_var + self.eps)
-            mean_first = (x - global_mean) / var_sqrt
-            y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)
-
-            mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
-            tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
-            mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var)
-            tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
-            y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
-            y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
         else:
             axes = (0,)
             re_shape = (1, self.num_features)
-            x_mean = self.reduce_mean(x, axes)
-            x_mean_square = self.reduce_mean(self.square(x), axes)
-            global_batch_mean = self.all_reduce(x_mean) / self.group
-            global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
-            global_mean = global_batch_mean
-            global_var = global_batch_mean_square - self.square(global_mean)
-            var_sqrt = self.sqrt(global_var + self.eps)
-            mean_first = (x - global_mean) / var_sqrt
-            y = mean_first * self.gamma + self.beta
-
-            mean_sub = self.sub_mean(self.moving_mean, global_mean)
-            temp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
-            mean_sub2 = self.sub_var(self.moving_variance, global_var)
-            temp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
-            y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), temp_mean))
-            y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), temp_variance))
+        return axes, re_shape
+
+    def _global_sync(self, x, axes, re_shape):
+        """calculate global batch normalization output"""
+        axes = (0, 2, 3)
+        re_shape = (1, self.num_features, 1, 1)
+        x_mean = self.reduce_mean(x, axes)
+        x_mean_square = self.reduce_mean(self.square(x), axes)
+        global_batch_mean = self.all_reduce(x_mean) / self.group
+        global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
+        global_mean = global_batch_mean
+        global_var = global_batch_mean_square - self.square(global_mean)
+        var_sqrt = self.sqrt(global_var + self.eps)
+        mean_first = (x - global_mean) / var_sqrt
+        y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)
+
+        mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
+        tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
+        mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var)
+        tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
+        y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
+        y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
         return y
 
     def construct(self, x):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
                 if self.is_global:
-                    y = self._global_sync(x)
+                    axes, re_shape = self._shape_infer(x)
+                    y = self._global_sync(x, axes, re_shape)
                 else:
                     y, batch_mean, batch_var, _, _ = \
                         self.bn_train(x,
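
For readers skimming the change: the two shape-specific branches of the old _global_sync computed identical synchronized statistics, so the commit factors the shape logic into _shape_infer and keeps a single synchronization path. The arithmetic is the usual cross-device BatchNorm reduction: each device computes a local mean and mean-of-squares, an AllReduce averages them over the group, and the global variance follows from Var[x] = E[x^2] - (E[x])^2. Below is a minimal NumPy sketch of that computation, not MindSpore API: the AllReduce is imitated by summing per-device shards, simulate_global_bn and its argument names are invented for illustration, the sketch uses the inferred axes/re_shape for both input ranks (the committed _global_sync re-assigns the 4-D values), and it updates the running variance from moving_variance, as the removed 2-D branch did.

    import numpy as np

    def shape_infer(x_shape, num_features):
        # Mirrors _shape_infer: NCHW reduces over (0, 2, 3), NC over (0,).
        if len(x_shape) == 4:
            return (0, 2, 3), (1, num_features, 1, 1)
        return (0,), (1, num_features)

    def simulate_global_bn(shards, gamma, beta, moving_mean, moving_variance,
                           momentum=0.9, eps=1e-5):
        # Hypothetical stand-in for _global_sync: `shards` plays the role of
        # the per-device inputs, and plain Python sums imitate AllReduce.
        group = len(shards)
        axes, re_shape = shape_infer(shards[0].shape, shards[0].shape[1])

        # AllReduce(local mean) / group        -> global E[x]
        global_mean = sum(s.mean(axis=axes) for s in shards) / group
        # AllReduce(local mean of squares) / group -> global E[x^2]
        mean_square = sum(np.square(s).mean(axis=axes) for s in shards) / group
        # Var[x] = E[x^2] - (E[x])^2, exactly as in the diff.
        global_var = mean_square - np.square(global_mean)

        gm = global_mean.reshape(re_shape)
        gv = global_var.reshape(re_shape)
        ys = [(s - gm) / np.sqrt(gv + eps) * gamma.reshape(re_shape)
              + beta.reshape(re_shape) for s in shards]

        # AssignSub pattern from the diff: moving -= momentum * (moving - global).
        moving_mean -= momentum * (moving_mean - global_mean)
        moving_variance -= momentum * (moving_variance - global_var)
        return ys

A quick check that the shards really share one set of statistics:

    rng = np.random.default_rng(0)
    shards = [rng.standard_normal((8, 4, 2, 2)) for _ in range(2)]
    gamma, beta = np.ones(4), np.zeros(4)
    mm, mv = np.zeros(4), np.ones(4)
    ys = simulate_global_bn(shards, gamma, beta, mm, mv)
    print(np.concatenate(ys).mean(axis=(0, 2, 3)))  # ~0 per channel
    print(np.concatenate(ys).var(axis=(0, 2, 3)))   # ~1 per channel

Factoring the shared body out of the 4-D/2-D branches removes the duplication that let the two paths drift apart, which appears to be the point of the refactor.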

