Browse Source

fix issue about assignsub in globalbn

tags/v0.3.0-alpha
zhaojichen 6 years ago
parent
commit
224bb73049
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      mindspore/nn/layer/normalization.py

+ 3
- 2
mindspore/nn/layer/normalization.py View File

@@ -139,8 +139,9 @@ class _BatchNorm(Cell):
tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var)
tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
y = F.depend(y, self.assign_sub_mean(self.moving_mean, self.reshape(tmp_mean, self.shape(self.moving_mean))))
y = F.depend(y, self.assign_sub_var(self.moving_variance,
self.reshape(tmp_variance, self.shape(self.moving_variance))))
return y

def construct(self, x):


Loading…
Cancel
Save