|
|
|
@@ -625,7 +625,21 @@ class FusedBatchNorm(Primitive):


class FusedBatchNormEx(PrimitiveWithInfer):
r"""
FusedBatchNormEx is an extension of FusedBatchNorm. Compared with FusedBatchNorm, it has one more
output, `reserve`, which is used in the backpropagation phase. FusedBatchNorm is a BatchNorm whose
moving mean and moving variance are computed during training instead of being loaded.

Batch Normalization is widely used in convolutional networks. This operation applies
Batch Normalization over the input to reduce internal covariate shift, as described in the
paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
features using a mini-batch of data and the learned parameters, as expressed in the
following formula.

.. math::

    y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

where :math:`\gamma` is the scale, :math:`\beta` is the bias, and :math:`\epsilon` is epsilon.

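As an illustration of the formula, the per-channel computation can be sketched with
NumPy (a minimal sketch only, not this operator's implementation; all names are local
to the sketch):

    >>> import numpy as np
    >>> x = np.random.randn(8, 4).astype(np.float32)        # (N, C) input
    >>> mean, variance = x.mean(axis=0), x.var(axis=0)      # per-channel batch statistics
    >>> gamma, beta, eps = np.ones(4), np.zeros(4), 1e-5    # scale, bias, epsilon
    >>> y = (x - mean) / np.sqrt(variance + eps) * gamma + beta
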
Args:
    mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
@@ -635,21 +649,25 @@ class FusedBatchNormEx(PrimitiveWithInfer):
    Momentum value must be in the range [0, 1]. Default: 0.9.

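During training the moving statistics are updated using `momentum`. The exact update
convention is not stated in this docstring; a common form (assumed here) is:

.. math::

    moving\_mean = momentum * moving\_mean + (1 - momentum) * mean
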
Inputs:
    - **input_x** (Tensor) - The input of FusedBatchNormEx, a Tensor of shape :math:`(N, C)`,
      data type: float16 or float32.
    - **scale** (Tensor) - Parameter scale, the :math:`\gamma` in the formula above, a Tensor of
      shape :math:`(C,)`, data type: float32.
    - **bias** (Tensor) - Parameter bias, the :math:`\beta` in the formula above, a Tensor of
      shape :math:`(C,)`, data type: float32.
    - **mean** (Tensor) - Mean value, a Tensor of shape :math:`(C,)`, data type: float32.
    - **variance** (Tensor) - Variance value, a Tensor of shape :math:`(C,)`, data type: float32.

Outputs:
    Tuple of 6 Tensors: the normalized input, the updated parameters, and the reserve.

    - **output_x** (Tensor) - The output of FusedBatchNormEx, with the same type and shape
      as `input_x`.
    - **updated_scale** (Tensor) - Updated parameter scale, a Tensor of shape :math:`(C,)`,
      data type: float32.
    - **updated_bias** (Tensor) - Updated parameter bias, a Tensor of shape :math:`(C,)`,
      data type: float32.
    - **updated_moving_mean** (Tensor) - Updated mean value, a Tensor of shape :math:`(C,)`,
      data type: float32.
    - **updated_moving_variance** (Tensor) - Updated variance value, a Tensor of shape
      :math:`(C,)`, data type: float32.
    - **reserve** (Tensor) - Reserve space, a Tensor of shape :math:`(C,)`, data type: float32.

Examples:
    >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
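    >>> # Assumed continuation of the truncated example: the remaining inputs follow
    >>> # the (C,) shapes documented above (C == 64 here), and P is assumed to be
    >>> # mindspore.ops.operations.
    >>> scale = Tensor(np.ones([64]), mindspore.float32)
    >>> bias = Tensor(np.ones([64]), mindspore.float32)
    >>> mean = Tensor(np.ones([64]), mindspore.float32)
    >>> variance = Tensor(np.ones([64]), mindspore.float32)
    >>> op = P.FusedBatchNormEx()
    >>> output = op(input_x, scale, bias, mean, variance)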
|