|
|
|
@@ -714,15 +714,20 @@ class FusedBatchNormEx(PrimitiveWithInfer): |
|
|
|
|
|
|
|
class BNTrainingReduce(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
reduce sum at axis [0, 2, 3]. |
|
|
|
For BatchNorm operator, this operator update the moving averages for training and is used in conjunction with |
|
|
|
BNTrainingUpdate. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - Tensor of shape :math:`(N, C)`. |
|
|
|
- **x** (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape :math:`(N, C, A, B)`. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **sum** (Tensor) - Tensor of shape :math:`(C,)`. |
|
|
|
- **square_sum** (Tensor) - Tensor of shape :math:`(C,)`. |
|
|
|
- **sum** (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape :math:`(C,)`. |
|
|
|
- **square_sum** (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape :math:`(C,)`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) |
|
|
|
>>> bn_training_reduce = P.BNTrainingReduce(input_x) |
|
|
|
>>> output = bn_training_reduce(input_x) |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
@@ -734,24 +739,90 @@ class BNTrainingReduce(PrimitiveWithInfer): |
|
|
|
return ([x_shape[1]], [x_shape[1]]) |
|
|
|
|
|
|
|
def infer_dtype(self, x_type): |
|
|
|
validator.check_tensor_type_same({"x_type": x_type}, [mstype.float16, mstype.float32], self.name) |
|
|
|
return (x_type, x_type) |
|
|
|
|
|
|
|
|
|
|
|
class BNTrainingUpdate(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
The primitive operator of the register and info descriptor in bn_training_update. |
|
|
|
For BatchNorm operator, this operator update the moving averages for training and is used in conjunction with |
|
|
|
BNTrainingReduce. |
|
|
|
|
|
|
|
Args: |
|
|
|
isRef (bool): If a ref. Default: True. |
|
|
|
epsilon (float): A small value added to variance avoid dividing by zero. Default: 1e-5. |
|
|
|
factor (float): A weight for updating the mean and variance. Default: 0.1. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **x** (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape :math:`(N, C, A, B)`. |
|
|
|
- **sum** (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator BNTrainingReduce. |
|
|
|
Tensor of shape :math:`(C,)`. |
|
|
|
- **square_sum** (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator |
|
|
|
BNTrainingReduce. Tensor of shape :math:`(C,)`. |
|
|
|
- **scale** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling factor. |
|
|
|
Tensor of shape :math:`(C,)`. |
|
|
|
- **offset** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling offset. |
|
|
|
Tensor of shape :math:`(C,)`. |
|
|
|
- **mean** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling mean. Tensor of shape :math:`(C,)`. |
|
|
|
- **variance** (Tensor) - A 1-D Tensor with float16 or float32, for the update variance. |
|
|
|
Tensor of shape :math:`(C,)`. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
- **y** (Tensor) - Tensor, has the same shape data type as `x`. |
|
|
|
- **mean** (Tensor) - Tensor for the updated mean, with float32 data type. |
|
|
|
Has the same shape as `variance`. |
|
|
|
- **variance** (Tensor) - Tensor for the updated variance, with float32 data type. |
|
|
|
Has the same shape as `variance`. |
|
|
|
- **batch_mean** (Tensor) - Tensor for the mean of `x`, with float32 data type. |
|
|
|
Has the same shape as `variance`. |
|
|
|
- **batch_variance** (Tensor) - Tensor for the mean of `variance`, with float32 data type. |
|
|
|
Has the same shape as `variance`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) |
|
|
|
>>> sum = Tensor(np.ones([64]), mindspore.float32) |
|
|
|
>>> square_sum = Tensor(np.ones([64]), mindspore.float32) |
|
|
|
>>> scale = Tensor(np.ones([64]), mindspore.float32) |
|
|
|
>>> offset = Tensor(np.ones([64]), mindspore.float32) |
|
|
|
>>> mean = Tensor(np.ones([64]), mindspore.float32) |
|
|
|
>>> variance = Tensor(np.ones([64]), mindspore.float32) |
|
|
|
>>> bn_training_update = P.BNTrainingUpdate() |
|
|
|
>>> output = bn_training_update(input_x, sum, square_sum, scale, offset, mean, variance) |
|
|
|
""" |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, isRef=True, epsilon=1e-5, factor=0.1): |
|
|
|
self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'], |
|
|
|
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) |
|
|
|
validator.check_value_type("isRef", isRef, [bool], self.name) |
|
|
|
validator.check_value_type("epsilon", epsilon, [float], self.name) |
|
|
|
validator.check_value_type("factor", factor, [float], self.name) |
|
|
|
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate') |
|
|
|
self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate') |
|
|
|
|
|
|
|
def infer_shape(self, x, sum, square_sum, scale, b, mean, variance): |
|
|
|
validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name) |
|
|
|
validator.check_integer("sum rank", len(sum), 1, Rel.EQ, self.name) |
|
|
|
validator.check_integer("square_sum rank", len(square_sum), 1, Rel.EQ, self.name) |
|
|
|
validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) |
|
|
|
validator.check_integer("b rank", len(b), 1, Rel.EQ, self.name) |
|
|
|
validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) |
|
|
|
validator.check_integer("variance rank", len(variance), 1, Rel.EQ, self.name) |
|
|
|
validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name) |
|
|
|
validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name) |
|
|
|
validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name) |
|
|
|
validator.check("offset shape", b, "x_shape[1]", x[1], Rel.EQ, self.name) |
|
|
|
validator.check("mean shape", mean, "x_shape[1]", x[1], Rel.EQ, self.name) |
|
|
|
validator.check("variance shape", variance, "x_shape[1]", x[1], Rel.EQ, self.name) |
|
|
|
return (x, variance, variance, variance, variance) |
|
|
|
|
|
|
|
def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance): |
|
|
|
validator.check_tensor_type_same({"x_type": x}, [mstype.float16, mstype.float32], self.name) |
|
|
|
validator.check_tensor_type_same({"sum_type": sum}, [mstype.float16, mstype.float32], self.name) |
|
|
|
validator.check_tensor_type_same({"square_sum_type": square_sum}, [mstype.float16, mstype.float32], self.name) |
|
|
|
validator.check_tensor_type_same({"scale_type": scale}, [mstype.float16, mstype.float32], self.name) |
|
|
|
validator.check_tensor_type_same({"b_type": b}, [mstype.float16, mstype.float32], self.name) |
|
|
|
validator.check_tensor_type_same({"mean_type": mean}, [mstype.float16, mstype.float32], self.name) |
|
|
|
validator.check_tensor_type_same({"variance_type": variance}, [mstype.float16, mstype.float32], self.name) |
|
|
|
return (x, variance, variance, variance, variance) |
|
|
|
|
|
|
|
|
|
|
|
|