@@ -156,19 +156,23 @@ class _BatchNorm(Cell):
                 axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                 y = self._global_sync(x, axes, re_shape)
             elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
-                y, batch_mean, batch_var, _, _ = \
-                    self.bn_train(x,
-                                  self.gamma,
-                                  self.beta,
-                                  None,
-                                  None)
-
-                mean_sub = self.sub_mean(self.moving_mean, batch_mean)
-                temp_mean = self.mul_mean(mean_sub, self.momentum)
-                mean_sub2 = self.sub_var(self.moving_variance, batch_var)
-                temp_variance = self.mul_var(mean_sub2, self.momentum)
-                y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
-                y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+                if self.is_global:
+                    axes, re_shape = _shape_infer(F.shape(x), self.num_features)
+                    y = self._global_sync(x, axes, re_shape)
+                else:
+                    y, batch_mean, batch_var, _, _ = \
+                        self.bn_train(x,
+                                      self.gamma,
+                                      self.beta,
+                                      None,
+                                      None)
+
+                    mean_sub = self.sub_mean(self.moving_mean, batch_mean)
+                    temp_mean = self.mul_mean(mean_sub, self.momentum)
+                    mean_sub2 = self.sub_var(self.moving_variance, batch_var)
+                    temp_variance = self.mul_var(mean_sub2, self.momentum)
+                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
+                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
             else:
                 y = self.bn_train(x,
                                   self.gamma,
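
For readers skimming the hunk: the sub_mean/mul_mean/assign_sub_mean (and the matching variance) sequence kept in the else branch amounts to an exponential-moving-average update of the running statistics. Below is a minimal plain-Python sketch of that update, assuming the usual reading moving <- moving - momentum * (moving - batch); it is an illustration, not code from this patch, and update_moving_stat is a hypothetical helper name.

# Illustrative sketch only (not part of the patch): the moving-statistics
# update expressed in plain Python, assuming the sub/mul/assign_sub ops in
# the hunk implement moving <- moving - momentum * (moving - batch).
def update_moving_stat(moving_stat, batch_stat, momentum):
    # Equivalent to (1 - momentum) * moving_stat + momentum * batch_stat,
    # i.e. an exponential moving average that weights the batch statistic
    # by `momentum`.
    return moving_stat - momentum * (moving_stat - batch_stat)

# Example: moving mean 0.5, batch mean 1.0, momentum 0.9 -> 0.95
print(update_moving_stat(0.5, 1.0, 0.9))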