|
|
|
@@ -45,7 +45,7 @@ class _BatchNorm(Cell): |
|
|
|
moving_var_init='ones', |
|
|
|
use_batch_statistics=None, |
|
|
|
device_num_each_group=1, |
|
|
|
input_dims='1d'): |
|
|
|
input_dims='2d'): |
|
|
|
super(_BatchNorm, self).__init__() |
|
|
|
if num_features < 1: |
|
|
|
raise ValueError("num_features must be at least 1") |
|
|
|
@@ -151,6 +151,8 @@ class _BatchNorm(Cell): |
|
|
|
_shape_check(self.shape(x)) |
|
|
|
if self.input_dims == '1d': |
|
|
|
_shape_check_2d(self.shape(x)) |
|
|
|
if self.input_dims == 'both': |
|
|
|
_shape_check_2d_or_4d(self.shape(x)) |
|
|
|
if self.use_batch_statistics is None: |
|
|
|
flag = self.training |
|
|
|
else: |
|
|
|
@@ -211,7 +213,13 @@ def _shape_check_2d(input_shape): |
|
|
|
@constexpr |
|
|
|
def _shape_check(in_shape): |
|
|
|
if len(in_shape) != 4: |
|
|
|
raise ValueError("The input must has 4 dims") |
|
|
|
raise ValueError("The input must has 4 dims.") |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _shape_check_2d_or_4d(in_shape): |
|
|
|
if len(in_shape) != 2 and len(in_shape) != 4: |
|
|
|
raise ValueError("The input must has 2 dims or 4 dims.") |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
@@ -449,7 +457,8 @@ class GlobalBatchNorm(_BatchNorm): |
|
|
|
moving_mean_init, |
|
|
|
moving_var_init, |
|
|
|
use_batch_statistics, |
|
|
|
device_num_each_group) |
|
|
|
device_num_each_group, |
|
|
|
input_dims='both') |
|
|
|
self.group = check_int_positive(device_num_each_group) |
|
|
|
if self.group <= 1: |
|
|
|
raise ValueError("the number of group must be greater than 1.") |
|
|
|
|