|
|
|
@@ -44,7 +44,8 @@ class _BatchNorm(Cell): |
|
|
|
moving_mean_init='zeros', |
|
|
|
moving_var_init='ones', |
|
|
|
use_batch_statistics=None, |
|
|
|
device_num_each_group=1): |
|
|
|
device_num_each_group=1, |
|
|
|
input_dims='1d'): |
|
|
|
super(_BatchNorm, self).__init__() |
|
|
|
if num_features < 1: |
|
|
|
raise ValueError("num_features must be at least 1") |
|
|
|
@@ -55,6 +56,7 @@ class _BatchNorm(Cell): |
|
|
|
self.use_batch_statistics = use_batch_statistics |
|
|
|
self.num_features = num_features |
|
|
|
self.eps = eps |
|
|
|
self.input_dims = input_dims |
|
|
|
self.moving_mean = Parameter(initializer( |
|
|
|
moving_mean_init, num_features), name="mean", requires_grad=False) |
|
|
|
self.moving_variance = Parameter(initializer( |
|
|
|
@@ -145,6 +147,8 @@ class _BatchNorm(Cell): |
|
|
|
return y |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
if self.input_dims == '2d': |
|
|
|
_shape_check(self.shape(x)) |
|
|
|
if self.use_batch_statistics is None: |
|
|
|
flag = self.training |
|
|
|
else: |
|
|
|
@@ -253,10 +257,10 @@ class BatchNorm1d(_BatchNorm): |
|
|
|
mean and variance. Default: None. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. |
|
|
|
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in})`. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`. |
|
|
|
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> net = nn.BatchNorm1d(num_features=16) |
|
|
|
@@ -282,7 +286,8 @@ class BatchNorm1d(_BatchNorm): |
|
|
|
beta_init, |
|
|
|
moving_mean_init, |
|
|
|
moving_var_init, |
|
|
|
use_batch_statistics) |
|
|
|
use_batch_statistics, |
|
|
|
input_dims='1d') |
|
|
|
|
|
|
|
def _check_data_dim(self, x): |
|
|
|
if x.dim() != 2: |
|
|
|
@@ -357,7 +362,8 @@ class BatchNorm2d(_BatchNorm): |
|
|
|
beta_init, |
|
|
|
moving_mean_init, |
|
|
|
moving_var_init, |
|
|
|
use_batch_statistics) |
|
|
|
use_batch_statistics, |
|
|
|
input_dims='2d') |
|
|
|
|
|
|
|
def _check_data_dim(self, x): |
|
|
|
if x.dim() != 4: |
|
|
|
|