|
|
|
@@ -101,6 +101,9 @@ class _BatchNorm(Cell): |
|
|
|
epsilon=self.eps, |
|
|
|
momentum=self.momentum) |
|
|
|
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps) |
|
|
|
self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend)) |
|
|
|
self.enable_default_train = self.is_graph_mode and not self.is_global and \ |
|
|
|
(self.is_ge_backend or self.is_ascend) |
|
|
|
|
|
|
|
data_parallel_strategy = ((1,), (1,)) |
|
|
|
data_parallel_strategy_one = ((1,), ()) |
|
|
|
@@ -147,51 +150,43 @@ class _BatchNorm(Cell): |
|
|
|
return y |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
if self.input_dims == '2d': |
|
|
|
_shape_check(self.shape(x)) |
|
|
|
if self.input_dims == '1d': |
|
|
|
_shape_check_2d(self.shape(x)) |
|
|
|
if self.input_dims == 'both': |
|
|
|
_shape_check_2d_or_4d(self.shape(x)) |
|
|
|
_shape_check_bn(self.shape(x), self.input_dims) |
|
|
|
if self.use_batch_statistics is None: |
|
|
|
flag = self.training |
|
|
|
else: |
|
|
|
flag = self.use_batch_statistics |
|
|
|
|
|
|
|
if flag: |
|
|
|
if self.is_ge_backend and self.is_global: |
|
|
|
if self.enable_global_sync: |
|
|
|
axes, re_shape = _shape_infer(F.shape(x), self.num_features) |
|
|
|
y = self._global_sync(x, axes, re_shape) |
|
|
|
elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend): |
|
|
|
if self.is_global: |
|
|
|
axes, re_shape = _shape_infer(F.shape(x), self.num_features) |
|
|
|
y = self._global_sync(x, axes, re_shape) |
|
|
|
else: |
|
|
|
y, batch_mean, batch_var, _, _ = \ |
|
|
|
self.bn_train(x, |
|
|
|
self.gamma, |
|
|
|
self.beta, |
|
|
|
None, |
|
|
|
None) |
|
|
|
|
|
|
|
mean_sub = self.sub_mean(self.moving_mean, batch_mean) |
|
|
|
temp_mean = self.mul_mean(mean_sub, self.momentum) |
|
|
|
mean_sub2 = self.sub_var(self.moving_variance, batch_var) |
|
|
|
temp_variance = self.mul_var(mean_sub2, self.momentum) |
|
|
|
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean)) |
|
|
|
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance)) |
|
|
|
else: |
|
|
|
y = self.bn_train(x, |
|
|
|
self.gamma, |
|
|
|
self.beta, |
|
|
|
self.moving_mean, |
|
|
|
self.moving_variance)[0] |
|
|
|
else: |
|
|
|
y = self.bn_infer(x, |
|
|
|
self.gamma, |
|
|
|
self.beta, |
|
|
|
self.moving_mean, |
|
|
|
self.moving_variance)[0] |
|
|
|
return y |
|
|
|
return self._global_sync(x, axes, re_shape) |
|
|
|
|
|
|
|
if self.enable_default_train: |
|
|
|
y, batch_mean, batch_var, _, _ = self.bn_train(x, |
|
|
|
self.gamma, |
|
|
|
self.beta, |
|
|
|
None, |
|
|
|
None) |
|
|
|
|
|
|
|
mean_sub = self.sub_mean(self.moving_mean, batch_mean) |
|
|
|
temp_mean = self.mul_mean(mean_sub, self.momentum) |
|
|
|
mean_sub2 = self.sub_var(self.moving_variance, batch_var) |
|
|
|
temp_variance = self.mul_var(mean_sub2, self.momentum) |
|
|
|
y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean)) |
|
|
|
y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance)) |
|
|
|
return y |
|
|
|
|
|
|
|
return self.bn_train(x, |
|
|
|
self.gamma, |
|
|
|
self.beta, |
|
|
|
self.moving_mean, |
|
|
|
self.moving_variance)[0] |
|
|
|
|
|
|
|
return self.bn_infer(x, |
|
|
|
self.gamma, |
|
|
|
self.beta, |
|
|
|
self.moving_mean, |
|
|
|
self.moving_variance)[0] |
|
|
|
|
|
|
|
def extend_repr(self): |
|
|
|
return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format( |
|
|
|
@@ -204,12 +199,6 @@ def _channel_check(channel, num_channel): |
|
|
|
raise ValueError("the input channel is not equal with num_channel") |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _shape_check_2d(input_shape): |
|
|
|
if len(input_shape) != 2: |
|
|
|
raise ValueError("The input must has 2 dims.") |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _shape_check(in_shape): |
|
|
|
if len(in_shape) != 4: |
|
|
|
@@ -217,8 +206,13 @@ def _shape_check(in_shape): |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _shape_check_2d_or_4d(in_shape): |
|
|
|
if len(in_shape) != 2 and len(in_shape) != 4: |
|
|
|
def _shape_check_bn(in_shape, in_dims): |
|
|
|
dim = len(in_shape) |
|
|
|
if in_dims == '1d' and dim != 2: |
|
|
|
raise ValueError("The input must has 2 dims.") |
|
|
|
if in_dims == '2d' and dim != 4: |
|
|
|
raise ValueError("The input must has 4 dims.") |
|
|
|
if in_dims == 'both' and dim != 2 and dim != 4: |
|
|
|
raise ValueError("The input must has 2 dims or 4 dims.") |
|
|
|
|
|
|
|
|
|
|
|
|