| @@ -101,6 +101,9 @@ class _BatchNorm(Cell): | |||||
| epsilon=self.eps, | epsilon=self.eps, | ||||
| momentum=self.momentum) | momentum=self.momentum) | ||||
| self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps) | self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps) | ||||
| self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend)) | |||||
| self.enable_default_train = self.is_graph_mode and not self.is_global and \ | |||||
| (self.is_ge_backend or self.is_ascend) | |||||
| data_parallel_strategy = ((1,), (1,)) | data_parallel_strategy = ((1,), (1,)) | ||||
| data_parallel_strategy_one = ((1,), ()) | data_parallel_strategy_one = ((1,), ()) | ||||
| @@ -147,51 +150,43 @@ class _BatchNorm(Cell): | |||||
| return y | return y | ||||
| def construct(self, x): | def construct(self, x): | ||||
| if self.input_dims == '2d': | |||||
| _shape_check(self.shape(x)) | |||||
| if self.input_dims == '1d': | |||||
| _shape_check_2d(self.shape(x)) | |||||
| if self.input_dims == 'both': | |||||
| _shape_check_2d_or_4d(self.shape(x)) | |||||
| _shape_check_bn(self.shape(x), self.input_dims) | |||||
| if self.use_batch_statistics is None: | if self.use_batch_statistics is None: | ||||
| flag = self.training | flag = self.training | ||||
| else: | else: | ||||
| flag = self.use_batch_statistics | flag = self.use_batch_statistics | ||||
| if flag: | if flag: | ||||
| if self.is_ge_backend and self.is_global: | |||||
| if self.enable_global_sync: | |||||
| axes, re_shape = _shape_infer(F.shape(x), self.num_features) | axes, re_shape = _shape_infer(F.shape(x), self.num_features) | ||||
| y = self._global_sync(x, axes, re_shape) | |||||
| elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend): | |||||
| if self.is_global: | |||||
| axes, re_shape = _shape_infer(F.shape(x), self.num_features) | |||||
| y = self._global_sync(x, axes, re_shape) | |||||
| else: | |||||
| y, batch_mean, batch_var, _, _ = \ | |||||
| self.bn_train(x, | |||||
| self.gamma, | |||||
| self.beta, | |||||
| None, | |||||
| None) | |||||
| mean_sub = self.sub_mean(self.moving_mean, batch_mean) | |||||
| temp_mean = self.mul_mean(mean_sub, self.momentum) | |||||
| mean_sub2 = self.sub_var(self.moving_variance, batch_var) | |||||
| temp_variance = self.mul_var(mean_sub2, self.momentum) | |||||
| y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean)) | |||||
| y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance)) | |||||
| else: | |||||
| y = self.bn_train(x, | |||||
| self.gamma, | |||||
| self.beta, | |||||
| self.moving_mean, | |||||
| self.moving_variance)[0] | |||||
| else: | |||||
| y = self.bn_infer(x, | |||||
| self.gamma, | |||||
| self.beta, | |||||
| self.moving_mean, | |||||
| self.moving_variance)[0] | |||||
| return y | |||||
| return self._global_sync(x, axes, re_shape) | |||||
| if self.enable_default_train: | |||||
| y, batch_mean, batch_var, _, _ = self.bn_train(x, | |||||
| self.gamma, | |||||
| self.beta, | |||||
| None, | |||||
| None) | |||||
| mean_sub = self.sub_mean(self.moving_mean, batch_mean) | |||||
| temp_mean = self.mul_mean(mean_sub, self.momentum) | |||||
| mean_sub2 = self.sub_var(self.moving_variance, batch_var) | |||||
| temp_variance = self.mul_var(mean_sub2, self.momentum) | |||||
| y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean)) | |||||
| y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance)) | |||||
| return y | |||||
| return self.bn_train(x, | |||||
| self.gamma, | |||||
| self.beta, | |||||
| self.moving_mean, | |||||
| self.moving_variance)[0] | |||||
| return self.bn_infer(x, | |||||
| self.gamma, | |||||
| self.beta, | |||||
| self.moving_mean, | |||||
| self.moving_variance)[0] | |||||
| def extend_repr(self): | def extend_repr(self): | ||||
| return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format( | return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format( | ||||
| @@ -204,12 +199,6 @@ def _channel_check(channel, num_channel): | |||||
| raise ValueError("the input channel is not equal with num_channel") | raise ValueError("the input channel is not equal with num_channel") | ||||
| @constexpr | |||||
| def _shape_check_2d(input_shape): | |||||
| if len(input_shape) != 2: | |||||
| raise ValueError("The input must has 2 dims.") | |||||
| @constexpr | @constexpr | ||||
| def _shape_check(in_shape): | def _shape_check(in_shape): | ||||
| if len(in_shape) != 4: | if len(in_shape) != 4: | ||||
| @@ -217,8 +206,13 @@ def _shape_check(in_shape): | |||||
| @constexpr | @constexpr | ||||
| def _shape_check_2d_or_4d(in_shape): | |||||
| if len(in_shape) != 2 and len(in_shape) != 4: | |||||
| def _shape_check_bn(in_shape, in_dims): | |||||
| dim = len(in_shape) | |||||
| if in_dims == '1d' and dim != 2: | |||||
| raise ValueError("The input must has 2 dims.") | |||||
| if in_dims == '2d' and dim != 4: | |||||
| raise ValueError("The input must has 4 dims.") | |||||
| if in_dims == 'both' and dim != 2 and dim != 4: | |||||
| raise ValueError("The input must has 2 dims or 4 dims.") | raise ValueError("The input must has 2 dims or 4 dims.") | ||||