|
|
|
@@ -366,15 +366,15 @@ class GlobalBatchNorm(_BatchNorm): |
|
|
|
use_batch_statistics=True, |
|
|
|
group=1): |
|
|
|
super(GlobalBatchNorm, self).__init__(num_features, |
|
|
|
eps, |
|
|
|
momentum, |
|
|
|
affine, |
|
|
|
gamma_init, |
|
|
|
beta_init, |
|
|
|
moving_mean_init, |
|
|
|
moving_var_init, |
|
|
|
use_batch_statistics, |
|
|
|
group) |
|
|
|
eps, |
|
|
|
momentum, |
|
|
|
affine, |
|
|
|
gamma_init, |
|
|
|
beta_init, |
|
|
|
moving_mean_init, |
|
|
|
moving_var_init, |
|
|
|
use_batch_statistics, |
|
|
|
group) |
|
|
|
self.group = check_int_positive(group) |
|
|
|
if self.group <= 1: |
|
|
|
raise ValueError("the number of group must be greater than 1.") |
|
|
|
|