| @@ -33,7 +33,6 @@ class _BatchNorm(Cell): | |||
| @cell_attr_register | |||
| def __init__(self, | |||
| num_features, | |||
| group=1, | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| affine=True, | |||
| @@ -41,7 +40,8 @@ class _BatchNorm(Cell): | |||
| beta_init='zeros', | |||
| moving_mean_init='zeros', | |||
| moving_var_init='ones', | |||
| use_batch_statistics=True): | |||
| use_batch_statistics=True, | |||
| group=1): | |||
| super(_BatchNorm, self).__init__() | |||
| if num_features < 1: | |||
| raise ValueError("num_features must be at least 1") | |||
| @@ -214,6 +214,25 @@ class BatchNorm1d(_BatchNorm): | |||
| >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32) | |||
| >>> net(input) | |||
| """ | |||
| def __init__(self, | |||
| num_features, | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| affine=True, | |||
| gamma_init='ones', | |||
| beta_init='zeros', | |||
| moving_mean_init='zeros', | |||
| moving_var_init='ones', | |||
| use_batch_statistics=True): | |||
| super(BatchNorm1d, self).__init__(num_features, | |||
| eps, | |||
| momentum, | |||
| affine, | |||
| gamma_init, | |||
| beta_init, | |||
| moving_mean_init, | |||
| moving_var_init, | |||
| use_batch_statistics) | |||
| def _check_data_dim(self, x): | |||
| if x.dim() != 2: | |||
| pass | |||
| @@ -266,6 +285,25 @@ class BatchNorm2d(_BatchNorm): | |||
| >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) | |||
| >>> net(input) | |||
| """ | |||
| def __init__(self, | |||
| num_features, | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| affine=True, | |||
| gamma_init='ones', | |||
| beta_init='zeros', | |||
| moving_mean_init='zeros', | |||
| moving_var_init='ones', | |||
| use_batch_statistics=True): | |||
| super(BatchNorm2d, self).__init__(num_features, | |||
| eps, | |||
| momentum, | |||
| affine, | |||
| gamma_init, | |||
| beta_init, | |||
| moving_mean_init, | |||
| moving_var_init, | |||
| use_batch_statistics) | |||
| def _check_data_dim(self, x): | |||
| if x.dim() != 4: | |||
| pass | |||
| @@ -316,6 +354,30 @@ class GlobalBatchNorm(_BatchNorm): | |||
| >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) | |||
| >>> global_bn_op(input) | |||
| """ | |||
| def __init__(self, | |||
| num_features, | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| affine=True, | |||
| gamma_init='ones', | |||
| beta_init='zeros', | |||
| moving_mean_init='zeros', | |||
| moving_var_init='ones', | |||
| use_batch_statistics=True, | |||
| group=1): | |||
| super(GlobalBatchNorm, self).__init__(num_features, | |||
| eps, | |||
| momentum, | |||
| affine, | |||
| gamma_init, | |||
| beta_init, | |||
| moving_mean_init, | |||
| moving_var_init, | |||
| use_batch_statistics, | |||
| group) | |||
| self.group = check_int_positive(group) | |||
| if self.group <=1: | |||
| raise ValueError("the number of group must be greater than 1.") | |||
| def _check_data_dim(self, x): | |||
| if x.dim == 0: | |||
| pass | |||