|
|
|
@@ -116,15 +116,7 @@ class _BatchNorm(Cell): |
|
|
|
group_list = [list(i) for i in world_rank_list] |
|
|
|
return group_list |
|
|
|
|
|
|
|
def _shape_infer(self, x): |
|
|
|
"""global batch normalization shape and axes infer""" |
|
|
|
if len(self.shape(x)) == 4: |
|
|
|
axes = (0, 2, 3) |
|
|
|
re_shape = (1, self.num_features, 1, 1) |
|
|
|
else: |
|
|
|
axes = (0,) |
|
|
|
re_shape = (1, self.num_features) |
|
|
|
return axes, re_shape |
|
|
|
|
|
|
|
|
|
|
|
def _global_sync(self, x, axes, re_shape): |
|
|
|
"""calculate global batch normalization output""" |
|
|
|
@@ -150,7 +142,7 @@ class _BatchNorm(Cell): |
|
|
|
if self.training and self.use_batch_statistics: |
|
|
|
if self.is_ge_backend: |
|
|
|
if self.is_global: |
|
|
|
axes, re_shape = self._shape_infer(x) |
|
|
|
axes, re_shape = _shape_infer(F.shape(x), self.num_features) |
|
|
|
y = self._global_sync(x, axes, re_shape) |
|
|
|
else: |
|
|
|
y, batch_mean, batch_var, _, _ = \ |
|
|
|
@@ -189,6 +181,17 @@ def _channel_check(channel, num_channel): |
|
|
|
if channel != num_channel: |
|
|
|
raise ValueError("the input channel is not equal with num_channel") |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _shape_infer(x_shape, num_feature): |
|
|
|
"""global batch normalization shape and axes infer""" |
|
|
|
if len(x_shape) == 4: |
|
|
|
axes = (0, 2, 3) |
|
|
|
re_shape = (1, num_feature, 1, 1) |
|
|
|
else: |
|
|
|
axes = (0,) |
|
|
|
re_shape = (1, num_feature) |
|
|
|
return axes, re_shape |
|
|
|
|
|
|
|
class BatchNorm1d(_BatchNorm): |
|
|
|
r""" |
|
|
|
Batch normalization layer over a 2D input. |
|
|
|
|