|
|
|
@@ -1267,12 +1267,20 @@ class BatchNorm(PrimitiveWithInfer): |
|
|
|
Default: "NCHW". |
|
|
|
|
|
|
|
Inputs: |
|
|
|
If `is_training` is False, inputs are Tensors. |
|
|
|
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. |
|
|
|
- **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type. |
|
|
|
- **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`. |
|
|
|
- **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type. |
|
|
|
- **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `mean`. |
|
|
|
|
|
|
|
If `is_training` is True, `scale`, `bias`, `mean` and `variance` are Parameters. |
|
|
|
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. |
|
|
|
- **scale** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type. |
|
|
|
- **bias** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`. |
|
|
|
- **mean** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type. |
|
|
|
- **variance** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `mean`. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tuple of 5 Tensor, the normalized inputs and the updated parameters. |
|
|
|
|
|
|
|
|