| @@ -116,15 +116,7 @@ class _BatchNorm(Cell): | |||||
| group_list = [list(i) for i in world_rank_list] | group_list = [list(i) for i in world_rank_list] | ||||
| return group_list | return group_list | ||||
| def _shape_infer(self, x): | |||||
| """global batch normalization shape and axes infer""" | |||||
| if len(self.shape(x)) == 4: | |||||
| axes = (0, 2, 3) | |||||
| re_shape = (1, self.num_features, 1, 1) | |||||
| else: | |||||
| axes = (0,) | |||||
| re_shape = (1, self.num_features) | |||||
| return axes, re_shape | |||||
| def _global_sync(self, x, axes, re_shape): | def _global_sync(self, x, axes, re_shape): | ||||
| """calculate global batch normalization output""" | """calculate global batch normalization output""" | ||||
| @@ -150,7 +142,7 @@ class _BatchNorm(Cell): | |||||
| if self.training and self.use_batch_statistics: | if self.training and self.use_batch_statistics: | ||||
| if self.is_ge_backend: | if self.is_ge_backend: | ||||
| if self.is_global: | if self.is_global: | ||||
| axes, re_shape = self._shape_infer(x) | |||||
| axes, re_shape = _shape_infer(F.shape(x), self.num_features) | |||||
| y = self._global_sync(x, axes, re_shape) | y = self._global_sync(x, axes, re_shape) | ||||
| else: | else: | ||||
| y, batch_mean, batch_var, _, _ = \ | y, batch_mean, batch_var, _, _ = \ | ||||
| @@ -189,6 +181,17 @@ def _channel_check(channel, num_channel): | |||||
| if channel != num_channel: | if channel != num_channel: | ||||
| raise ValueError("the input channel is not equal with num_channel") | raise ValueError("the input channel is not equal with num_channel") | ||||
| @constexpr | |||||
| def _shape_infer(x_shape, num_feature): | |||||
| """global batch normalization shape and axes infer""" | |||||
| if len(x_shape) == 4: | |||||
| axes = (0, 2, 3) | |||||
| re_shape = (1, num_feature, 1, 1) | |||||
| else: | |||||
| axes = (0,) | |||||
| re_shape = (1, num_feature) | |||||
| return axes, re_shape | |||||
| class BatchNorm1d(_BatchNorm): | class BatchNorm1d(_BatchNorm): | ||||
| r""" | r""" | ||||
| Batch normalization layer over a 2D input. | Batch normalization layer over a 2D input. | ||||