From: @liu_xiao_93 Reviewed-by: @liangchenghui Signed-off-by: @liangchenghuitags/v1.1.0
| @@ -53,7 +53,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||||
| NPUAllocFloatStatus, NPUClearFloatStatus, | NPUAllocFloatStatus, NPUClearFloatStatus, | ||||
| NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, | ||||
| Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, | Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, | ||||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, IFMR, | |||||
| Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, | |||||
| Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorDot) | Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorDot) | ||||
| from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal, | ||||
| @@ -98,7 +98,6 @@ __all__ = [ | |||||
| 'EditDistance', | 'EditDistance', | ||||
| 'CropAndResize', | 'CropAndResize', | ||||
| 'TensorAdd', | 'TensorAdd', | ||||
| 'IFMR', | |||||
| 'Argmax', | 'Argmax', | ||||
| 'Argmin', | 'Argmin', | ||||
| 'ArgMaxWithValue', | 'ArgMaxWithValue', | ||||
| @@ -43,7 +43,8 @@ __all__ = ["MinMaxUpdatePerLayer", | |||||
| "BatchNormFoldGradD", | "BatchNormFoldGradD", | ||||
| "BatchNormFold2_D", | "BatchNormFold2_D", | ||||
| "BatchNormFold2GradD", | "BatchNormFold2GradD", | ||||
| "BatchNormFold2GradReduce" | |||||
| "BatchNormFold2GradReduce", | |||||
| "IFMR" | |||||
| ] | ] | ||||
| @@ -1384,3 +1385,66 @@ class WtsARQ(PrimitiveWithInfer): | |||||
| validator.check_tensor_type_same({"w_min": w_min_dtype}, valid_types, self.name) | validator.check_tensor_type_same({"w_min": w_min_dtype}, valid_types, self.name) | ||||
| validator.check_tensor_type_same({"w_max": w_max_dtype}, valid_types, self.name) | validator.check_tensor_type_same({"w_max": w_max_dtype}, valid_types, self.name) | ||||
| return w_dtype | return w_dtype | ||||
| class IFMR(PrimitiveWithInfer): | |||||
| """ | |||||
| The TFMR(Input Feature Map Reconstruction). | |||||
| Args: | |||||
| min_percentile (float): Min init percentile. Default: 0.999999. | |||||
| max_percentile (float): Max init percentile. Default: 0.999999. | |||||
| search_range Union[list(float), tuple(float)]: Range of searching. Default: [0.7, 1.3]. | |||||
| search_step (float): Step size of searching. Default: 0.01. | |||||
| with_offset (bool): Whether using offset. Default: True. | |||||
| Inputs: | |||||
| - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type. | |||||
| - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`. | |||||
| With float16 or float32 data type. | |||||
| - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`. | |||||
| With float16 or float32 data type. | |||||
| - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type. | |||||
| Outputs: | |||||
| - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32. | |||||
| - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32. | |||||
| Examples: | |||||
| >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32)) | |||||
| >>> data_min = Tensor([0.1], mstype.float32) | |||||
| >>> data_max = Tensor([0.5], mstype.float32) | |||||
| >>> cumsum = Tensor(np.random.rand(4).astype(np.int32)) | |||||
| >>> ifmr = Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0), | |||||
| >>> search_step=1.0, with_offset=False) | |||||
| >>> output = ifmr(data, data_min, data_max, cumsum) | |||||
| ([7.87401572e-03], [0.00000000e+00]) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01, | |||||
| with_offset=True): | |||||
| validator.check_value_type("min_percentile", min_percentile, [float], self.name) | |||||
| validator.check_value_type("max_percentile", max_percentile, [float], self.name) | |||||
| validator.check_value_type("search_range", search_range, [list, tuple], self.name) | |||||
| for item in search_range: | |||||
| validator.check_positive_float(item, "item of search_range", self.name) | |||||
| validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name) | |||||
| validator.check_value_type("search_step", search_step, [float], self.name) | |||||
| validator.check_value_type("offset_flag", with_offset, [bool], self.name) | |||||
| def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape): | |||||
| validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name) | |||||
| validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name) | |||||
| validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name) | |||||
| validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name) | |||||
| validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name) | |||||
| return (1,), (1,) | |||||
| def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ("input_value", "input_min", "input_max"), | |||||
| (data_dtype, data_min_dtype, data_max_dtype))) | |||||
| validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name) | |||||
| return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32) | |||||
| @@ -16,7 +16,6 @@ | |||||
| """Operators for math.""" | """Operators for math.""" | ||||
| import copy | import copy | ||||
| from functools import partial | |||||
| import numpy as np | import numpy as np | ||||
| from ... import context | from ... import context | ||||
| @@ -3680,66 +3679,3 @@ class Eps(PrimitiveWithInfer): | |||||
| 'dtype': input_x['dtype'], | 'dtype': input_x['dtype'], | ||||
| } | } | ||||
| return out | return out | ||||
| class IFMR(PrimitiveWithInfer): | |||||
| """ | |||||
| The TFMR(Input Feature Map Reconstruction). | |||||
| Args: | |||||
| min_percentile (float): Min init percentile. Default: 0.999999. | |||||
| max_percentile (float): Max init percentile. Default: 0.999999. | |||||
| search_range Union[list(float), tuple(float)]: Range of searching. Default: [0.7, 1.3]. | |||||
| search_step (float): Step size of searching. Default: 0.01. | |||||
| with_offset (bool): Whether using offset. Default: True. | |||||
| Inputs: | |||||
| - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type. | |||||
| - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`. | |||||
| With float16 or float32 data type. | |||||
| - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`. | |||||
| With float16 or float32 data type. | |||||
| - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type. | |||||
| Outputs: | |||||
| - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32. | |||||
| - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32. | |||||
| Examples: | |||||
| >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32)) | |||||
| >>> data_min = Tensor([0.1], mstype.float32) | |||||
| >>> data_max = Tensor([0.5], mstype.float32) | |||||
| >>> cumsum = Tensor(np.random.rand(4).astype(np.int32)) | |||||
| >>> ifmr = P.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0), | |||||
| search_step=1.0, with_offset=False) | |||||
| >>> output = ifmr(data, data_min, data_max, cumsum) | |||||
| ([7.87401572e-03], [0.00000000e+00]) | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01, | |||||
| with_offset=True): | |||||
| validator.check_value_type("min_percentile", min_percentile, [float], self.name) | |||||
| validator.check_value_type("max_percentile", max_percentile, [float], self.name) | |||||
| validator.check_value_type("search_range", search_range, [list, tuple], self.name) | |||||
| for item in search_range: | |||||
| validator.check_positive_float(item, "item of search_range", self.name) | |||||
| validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name) | |||||
| validator.check_value_type("search_step", search_step, [float], self.name) | |||||
| validator.check_value_type("offset_flag", with_offset, [bool], self.name) | |||||
| def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape): | |||||
| validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name) | |||||
| validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name) | |||||
| validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name) | |||||
| validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name) | |||||
| validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name) | |||||
| return (1,), (1,) | |||||
| def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype): | |||||
| tuple(map(partial(validator.check_tensor_dtype_valid, | |||||
| valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), | |||||
| ("input_value", "input_min", "input_max"), | |||||
| (data_dtype, data_min_dtype, data_max_dtype))) | |||||
| validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name) | |||||
| return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32) | |||||
| @@ -601,10 +601,10 @@ class FusedBatchNorm(Primitive): | |||||
| Inputs: | Inputs: | ||||
| - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`. | - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`. | ||||
| - **scale** (Tensor) - Tensor of shape :math:`(C,)`. | |||||
| - **bias** (Tensor) - Tensor of shape :math:`(C,)`. | |||||
| - **mean** (Tensor) - Tensor of shape :math:`(C,)`. | |||||
| - **variance** (Tensor) - Tensor of shape :math:`(C,)`. | |||||
| - **scale** (Parameter) - Tensor of shape :math:`(C,)`. | |||||
| - **bias** (Parameter) - Tensor of shape :math:`(C,)`. | |||||
| - **mean** (Parameter) - Tensor of shape :math:`(C,)`. | |||||
| - **variance** (Parameter) - Tensor of shape :math:`(C,)`. | |||||
| Outputs: | Outputs: | ||||
| Tuple of 5 Tensor, the normalized input and the updated parameters. | Tuple of 5 Tensor, the normalized input and the updated parameters. | ||||
| @@ -616,13 +616,30 @@ class FusedBatchNorm(Primitive): | |||||
| - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`. | - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`. | ||||
| Examples: | Examples: | ||||
| >>> import mindspore | |||||
| >>> import mindspore.nn as nn | |||||
| >>> import numpy as np | |||||
| >>> from mindspore import Parameter | |||||
| >>> from mindspore import Tensor | |||||
| >>> from mindspore.ops import operations as P | |||||
| >>> class FusedBatchNormNet(nn.Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(FusedBatchNormNet, self).__init__() | |||||
| >>> self.fused_batch_norm = P.FusedBatchNorm() | |||||
| >>> self.scale = Parameter(Tensor(np.ones([64]), mindspore.float32), name="scale") | |||||
| >>> self.bias = Parameter(Tensor(np.ones([64]), mindspore.float32), name="bias") | |||||
| >>> self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean") | |||||
| >>> self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance") | |||||
| >>> | |||||
| >>> def construct(self, input_x): | |||||
| >>> out = self.fused_batch_norm(input_x, self.scale, self.bias, self.mean, self.variance) | |||||
| >>> return out | |||||
| >>> | |||||
| >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) | >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) | ||||
| >>> scale = Tensor(np.ones([64]), mindspore.float32) | |||||
| >>> bias = Tensor(np.ones([64]), mindspore.float32) | |||||
| >>> mean = Tensor(np.ones([64]), mindspore.float32) | |||||
| >>> variance = Tensor(np.ones([64]), mindspore.float32) | |||||
| >>> op = P.FusedBatchNorm() | |||||
| >>> output = op(input_x, scale, bias, mean, variance) | |||||
| >>> net = FusedBatchNormNet() | |||||
| >>> output = net(input_x) | |||||
| >>> output[0].shape | |||||
| (128, 64, 32, 64) | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| sig.make_sig('input_x', dtype=sig.sig_dtype.T2), | sig.make_sig('input_x', dtype=sig.sig_dtype.T2), | ||||
| @@ -673,12 +690,12 @@ class FusedBatchNormEx(PrimitiveWithInfer): | |||||
| Inputs: | Inputs: | ||||
| - **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`, | - **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`, | ||||
| data type: float16 or float32. | data type: float16 or float32. | ||||
| - **scale** (Tensor) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`, | |||||
| - **scale** (Parameter) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`, | |||||
| data type: float32. | data type: float32. | ||||
| - **bias** (Tensor) - Parameter bias, same with beta above-mentioned, Tensor of shape :math:`(C,)`, | |||||
| - **bias** (Parameter) - Parameter bias, same with beta above-mentioned, Tensor of shape :math:`(C,)`, | |||||
| data type: float32. | data type: float32. | ||||
| - **mean** (Tensor) - mean value, Tensor of shape :math:`(C,)`, data type: float32. | |||||
| - **variance** (Tensor) - variance value, Tensor of shape :math:`(C,)`, data type: float32. | |||||
| - **mean** (Parameter) - mean value, Tensor of shape :math:`(C,)`, data type: float32. | |||||
| - **variance** (Parameter) - variance value, Tensor of shape :math:`(C,)`, data type: float32. | |||||
| Outputs: | Outputs: | ||||
| Tuple of 6 Tensors, the normalized input, the updated parameters and reserve. | Tuple of 6 Tensors, the normalized input, the updated parameters and reserve. | ||||
| @@ -692,13 +709,30 @@ class FusedBatchNormEx(PrimitiveWithInfer): | |||||
| - **reserve** (Tensor) - reserve space, Tensor of shape :math:`(C,)`, data type: float32. | - **reserve** (Tensor) - reserve space, Tensor of shape :math:`(C,)`, data type: float32. | ||||
| Examples: | Examples: | ||||
| >>> import mindspore | |||||
| >>> import mindspore.nn as nn | |||||
| >>> import numpy as np | |||||
| >>> from mindspore import Parameter | |||||
| >>> from mindspore import Tensor | |||||
| >>> from mindspore.ops import operations as P | |||||
| >>> class FusedBatchNormExNet(nn.Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(FusedBatchNormExNet, self).__init__() | |||||
| >>> self.fused_batch_norm_ex = P.FusedBatchNormEx() | |||||
| >>> self.scale = Parameter(Tensor(np.ones([64]), mindspore.float32), name="scale") | |||||
| >>> self.bias = Parameter(Tensor(np.ones([64]), mindspore.float32), name="bias") | |||||
| >>> self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean") | |||||
| >>> self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance") | |||||
| >>> | |||||
| >>> def construct(self, input_x): | |||||
| >>> out = self.fused_batch_norm_ex(input_x, self.scale, self.bias, self.mean, self.variance) | |||||
| >>> return out | |||||
| >>> | |||||
| >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) | >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) | ||||
| >>> scale = Tensor(np.ones([64]), mindspore.float32) | |||||
| >>> bias = Tensor(np.ones([64]), mindspore.float32) | |||||
| >>> mean = Tensor(np.ones([64]), mindspore.float32) | |||||
| >>> variance = Tensor(np.ones([64]), mindspore.float32) | |||||
| >>> op = P.FusedBatchNormEx() | |||||
| >>> output = op(input_x, scale, bias, mean, variance) | |||||
| >>> net = FusedBatchNormExNet() | |||||
| >>> output = net(input_x) | |||||
| >>> output[0].shape | |||||
| (128, 64, 32, 64) | |||||
| """ | """ | ||||
| __mindspore_signature__ = ( | __mindspore_signature__ = ( | ||||
| sig.make_sig('input_x', dtype=sig.sig_dtype.T2), | sig.make_sig('input_x', dtype=sig.sig_dtype.T2), | ||||
| @@ -756,7 +790,7 @@ class BNTrainingReduce(PrimitiveWithInfer): | |||||
| Examples: | Examples: | ||||
| >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) | >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32) | ||||
| >>> bn_training_reduce = P.BNTrainingReduce(input_x) | |||||
| >>> bn_training_reduce = P.BNTrainingReduce() | |||||
| >>> output = bn_training_reduce(input_x) | >>> output = bn_training_reduce(input_x) | ||||
| """ | """ | ||||
| @@ -5662,13 +5696,30 @@ class DynamicRNN(PrimitiveWithInfer): | |||||
| Has the same type with input `b`. | Has the same type with input `b`. | ||||
| Examples: | Examples: | ||||
| >>> import mindspore | |||||
| >>> import mindspore.nn as nn | |||||
| >>> import numpy as np | |||||
| >>> from mindspore import Parameter | |||||
| >>> from mindspore import Tensor | |||||
| >>> from mindspore.ops import operations as P | |||||
| >>> import mindspore.context as context | |||||
| >>> context.set_context(mode=context.GRAPH_MODE) | |||||
| >>> class DynamicRNNNet(nn.Cell): | |||||
| >>> def __init__(self): | |||||
| >>> super(DynamicRNNNet, self).__init__() | |||||
| >>> self.dynamic_rnn = P.DynamicRNN() | |||||
| >>> | |||||
| >>> def construct(self, x, w, b, init_h, init_c): | |||||
| >>> out = self.dynamic_rnn(x, w, b, None, init_h, init_c) | |||||
| >>> return out | |||||
| >>> | |||||
| >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16)) | >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16)) | ||||
| >>> w = Tensor(np.random.rand(96, 128).astype(np.float16)) | >>> w = Tensor(np.random.rand(96, 128).astype(np.float16)) | ||||
| >>> b = Tensor(np.random.rand(128).astype(np.float16)) | >>> b = Tensor(np.random.rand(128).astype(np.float16)) | ||||
| >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) | >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) | ||||
| >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) | >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16)) | ||||
| >>> dynamic_rnn = P.DynamicRNN() | |||||
| >>> output = dynamic_rnn(x, w, b, None, init_h, init_c) | |||||
| >>> net = DynamicRNNNet() | |||||
| >>> output = net(x, w, b, init_h, init_c) | |||||
| >>> output[0].shape | >>> output[0].shape | ||||
| (2, 16, 32) | (2, 16, 32) | ||||
| """ | """ | ||||
| @@ -1446,7 +1446,7 @@ test_case_math_ops = [ | |||||
| 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | 'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]], | ||||
| 'desc_bprop': [[2, 3, 4, 5]]}), | 'desc_bprop': [[2, 3, 4, 5]]}), | ||||
| ('IFMR', { | ('IFMR', { | ||||
| 'block': P.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0), | |||||
| 'block': Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0), | |||||
| search_step=1.0, with_offset=False), | search_step=1.0, with_offset=False), | ||||
| 'desc_inputs': [[3, 4, 5], Tensor([0.1], mstype.float32), Tensor([0.9], mstype.float32), | 'desc_inputs': [[3, 4, 5], Tensor([0.1], mstype.float32), Tensor([0.9], mstype.float32), | ||||
| Tensor(np.random.rand(4).astype(np.int32))], | Tensor(np.random.rand(4).astype(np.int32))], | ||||