|
|
|
@@ -894,11 +894,7 @@ class InstanceNorm2d(Cell): |
|
|
|
\gamma and \beta are learnable parameter vectors of size num_features if affine is True. The standard-deviation |
|
|
|
is calculated via the biased estimator. |
|
|
|
|
|
|
|
By default, this layer uses instance statistics computed from input data in both training and evaluation modes. |
|
|
|
|
|
|
|
If use_batch_statistics is set to True, it means training phases, and this layer keeps running estimates of its |
|
|
|
computed mean and variance, which are then used for normalization during evaluation. The running estimates are |
|
|
|
kept with a default momentum of 0.1. |
|
|
|
This layer uses instance statistics computed from input data in both training and evaluation modes. |
|
|
|
|
|
|
|
InstanceNorm2d and BatchNorm2d are very similar, but have some differences. InstanceNorm2d is applied on each |
|
|
|
channel of channeled data like RGB images, but BatchNorm2d is usually applied on each batch of batched data. |
|
|
|
@@ -918,12 +914,6 @@ class InstanceNorm2d(Cell): |
|
|
|
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. |
|
|
|
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. |
|
|
|
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. |
|
|
|
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean. |
|
|
|
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. |
|
|
|
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. |
|
|
|
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. |
|
|
|
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, |
|
|
|
use the mean value and variance value of specified value. Default: True. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32. |
|
|
|
@@ -940,12 +930,12 @@ class InstanceNorm2d(Cell): |
|
|
|
TypeError: If `eps` is not a float. |
|
|
|
TypeError: If `momentum` is not a float. |
|
|
|
TypeError: If `affine` is not a bool. |
|
|
|
TypeError: If the type of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is not same, or if |
|
|
|
the initialized element type is not float32. |
|
|
|
TypeError: If the type of `gamma_init`/`beta_init` is not same, or if the initialized element type is not |
|
|
|
float32. |
|
|
|
ValueError: If `num_features` is less than 1. |
|
|
|
ValueError: If `momentum` is not in range [0, 1]. |
|
|
|
KeyError: If any of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is str and the homonymous |
|
|
|
class inheriting from `Initializer` not exists. |
|
|
|
KeyError: If any of `gamma_init`/`beta_init` is str and the homonymous class inheriting from `Initializer` not |
|
|
|
exists. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> import mindspore |
|
|
|
@@ -966,31 +956,24 @@ class InstanceNorm2d(Cell): |
|
|
|
momentum=0.1, |
|
|
|
affine=True, |
|
|
|
gamma_init='ones', |
|
|
|
beta_init='zeros', |
|
|
|
moving_mean_init='zeros', |
|
|
|
moving_var_init='ones', |
|
|
|
use_batch_statistics=True): |
|
|
|
beta_init='zeros'): |
|
|
|
super(InstanceNorm2d, self).__init__() |
|
|
|
validator.check_value_type('num_features', num_features, [int], self.cls_name) |
|
|
|
validator.check_value_type('eps', eps, [float], self.cls_name) |
|
|
|
validator.check_value_type('momentum', momentum, [float], self.cls_name) |
|
|
|
validator.check_value_type('affine', affine, [bool], self.cls_name) |
|
|
|
args_input = {"gamma_init": gamma_init, "beta_init": beta_init, |
|
|
|
"moving_mean_init": moving_mean_init, "moving_var_init": moving_var_init} |
|
|
|
args_input = {"gamma_init": gamma_init, "beta_init": beta_init} |
|
|
|
self.check_types_valid(args_input, 'InstanceNorm2d') |
|
|
|
if num_features < 1: |
|
|
|
raise ValueError("num_features must be at least 1") |
|
|
|
|
|
|
|
if momentum < 0 or momentum > 1: |
|
|
|
raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum)) |
|
|
|
self.use_batch_statistics = use_batch_statistics |
|
|
|
self.num_features = num_features |
|
|
|
self.eps = eps |
|
|
|
self.input_dims = '2d' |
|
|
|
self.moving_mean = Parameter(initializer( |
|
|
|
moving_mean_init, num_features), name="mean", requires_grad=False) |
|
|
|
self.moving_variance = Parameter(initializer( |
|
|
|
moving_var_init, num_features), name="variance", requires_grad=False) |
|
|
|
self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False) |
|
|
|
self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False) |
|
|
|
self.gamma = Parameter(initializer( |
|
|
|
gamma_init, num_features), name="gamma", requires_grad=affine) |
|
|
|
self.beta = Parameter(initializer( |
|
|
|
@@ -998,9 +981,7 @@ class InstanceNorm2d(Cell): |
|
|
|
|
|
|
|
self.shape = P.Shape() |
|
|
|
self.momentum = momentum |
|
|
|
self.instance_bn = P.InstanceNorm(is_training=self.use_batch_statistics, |
|
|
|
epsilon=self.eps, |
|
|
|
momentum=self.momentum) |
|
|
|
self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum) |
|
|
|
|
|
|
|
def _check_data_dim(self, x): |
|
|
|
raise NotImplementedError |
|
|
|
|