|
|
@@ -33,7 +33,6 @@ class _BatchNorm(Cell): |
|
|
@cell_attr_register |
|
|
@cell_attr_register |
|
|
def __init__(self, |
|
|
def __init__(self, |
|
|
num_features, |
|
|
num_features, |
|
|
group=1, |
|
|
|
|
|
eps=1e-5, |
|
|
eps=1e-5, |
|
|
momentum=0.9, |
|
|
momentum=0.9, |
|
|
affine=True, |
|
|
affine=True, |
|
|
@@ -41,7 +40,8 @@ class _BatchNorm(Cell): |
|
|
beta_init='zeros', |
|
|
beta_init='zeros', |
|
|
moving_mean_init='zeros', |
|
|
moving_mean_init='zeros', |
|
|
moving_var_init='ones', |
|
|
moving_var_init='ones', |
|
|
use_batch_statistics=True): |
|
|
|
|
|
|
|
|
use_batch_statistics=True, |
|
|
|
|
|
group=1): |
|
|
super(_BatchNorm, self).__init__() |
|
|
super(_BatchNorm, self).__init__() |
|
|
if num_features < 1: |
|
|
if num_features < 1: |
|
|
raise ValueError("num_features must be at least 1") |
|
|
raise ValueError("num_features must be at least 1") |
|
|
@@ -214,6 +214,25 @@ class BatchNorm1d(_BatchNorm): |
|
|
>>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32) |
|
|
>>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32) |
|
|
>>> net(input) |
|
|
>>> net(input) |
|
|
""" |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
|
|
|
num_features, |
|
|
|
|
|
eps=1e-5, |
|
|
|
|
|
momentum=0.9, |
|
|
|
|
|
affine=True, |
|
|
|
|
|
gamma_init='ones', |
|
|
|
|
|
beta_init='zeros', |
|
|
|
|
|
moving_mean_init='zeros', |
|
|
|
|
|
moving_var_init='ones', |
|
|
|
|
|
use_batch_statistics=True): |
|
|
|
|
|
super(BatchNorm1d, self).__init__(num_features, |
|
|
|
|
|
eps, |
|
|
|
|
|
momentum, |
|
|
|
|
|
affine, |
|
|
|
|
|
gamma_init, |
|
|
|
|
|
beta_init, |
|
|
|
|
|
moving_mean_init, |
|
|
|
|
|
moving_var_init, |
|
|
|
|
|
use_batch_statistics) |
|
|
def _check_data_dim(self, x): |
|
|
def _check_data_dim(self, x): |
|
|
if x.dim() != 2: |
|
|
if x.dim() != 2: |
|
|
pass |
|
|
pass |
|
|
@@ -266,6 +285,25 @@ class BatchNorm2d(_BatchNorm): |
|
|
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) |
|
|
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) |
|
|
>>> net(input) |
|
|
>>> net(input) |
|
|
""" |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
|
|
|
num_features, |
|
|
|
|
|
eps=1e-5, |
|
|
|
|
|
momentum=0.9, |
|
|
|
|
|
affine=True, |
|
|
|
|
|
gamma_init='ones', |
|
|
|
|
|
beta_init='zeros', |
|
|
|
|
|
moving_mean_init='zeros', |
|
|
|
|
|
moving_var_init='ones', |
|
|
|
|
|
use_batch_statistics=True): |
|
|
|
|
|
super(BatchNorm2d, self).__init__(num_features, |
|
|
|
|
|
eps, |
|
|
|
|
|
momentum, |
|
|
|
|
|
affine, |
|
|
|
|
|
gamma_init, |
|
|
|
|
|
beta_init, |
|
|
|
|
|
moving_mean_init, |
|
|
|
|
|
moving_var_init, |
|
|
|
|
|
use_batch_statistics) |
|
|
def _check_data_dim(self, x): |
|
|
def _check_data_dim(self, x): |
|
|
if x.dim() != 4: |
|
|
if x.dim() != 4: |
|
|
pass |
|
|
pass |
|
|
@@ -316,6 +354,30 @@ class GlobalBatchNorm(_BatchNorm): |
|
|
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) |
|
|
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) |
|
|
>>> global_bn_op(input) |
|
|
>>> global_bn_op(input) |
|
|
""" |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
|
|
|
num_features, |
|
|
|
|
|
eps=1e-5, |
|
|
|
|
|
momentum=0.9, |
|
|
|
|
|
affine=True, |
|
|
|
|
|
gamma_init='ones', |
|
|
|
|
|
beta_init='zeros', |
|
|
|
|
|
moving_mean_init='zeros', |
|
|
|
|
|
moving_var_init='ones', |
|
|
|
|
|
use_batch_statistics=True, |
|
|
|
|
|
group=1): |
|
|
|
|
|
super(GlobalBatchNorm, self).__init__(num_features, |
|
|
|
|
|
eps, |
|
|
|
|
|
momentum, |
|
|
|
|
|
affine, |
|
|
|
|
|
gamma_init, |
|
|
|
|
|
beta_init, |
|
|
|
|
|
moving_mean_init, |
|
|
|
|
|
moving_var_init, |
|
|
|
|
|
use_batch_statistics, |
|
|
|
|
|
group) |
|
|
|
|
|
self.group = check_int_positive(group) |
|
|
|
|
|
if self.group <= 1: |
|
|
|
|
|
raise ValueError("the number of group must be greater than 1.") |
|
|
def _check_data_dim(self, x): |
|
|
def _check_data_dim(self, x): |
|
|
if x.dim == 0: |
|
|
if x.dim == 0: |
|
|
pass |
|
|
pass |
|
|
|