From: @david-he91 Reviewed-by: @liangchenghui,@guoqi1024 Signed-off-by: @liangchenghui tags/v1.2.0-rc1
| @@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | |||
| TensorSummary, HistogramSummary, Print, Assert) | |||
| from .control_ops import ControlDepend, GeSwitch, Merge | |||
| from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey, Centralization | |||
| from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey | |||
| from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, | |||
| BitwiseAnd, BitwiseOr, | |||
| @@ -22,7 +22,7 @@ from ...common import dtype as mstype | |||
| from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register | |||
| from ..operations.math_ops import _infer_shape_reduce | |||
| from ...communication.management import GlobalComm | |||
| from .. import signature as sig | |||
| class ExtractImagePatches(PrimitiveWithInfer): | |||
| """ | |||
| @@ -815,3 +815,70 @@ class SyncBatchNorm(PrimitiveWithInfer): | |||
| args_moving = {"mean": mean, "variance": variance} | |||
| validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name) | |||
| return (input_x, scale, bias, input_x, input_x) | |||
class Centralization(PrimitiveWithInfer):
    """
    Computes centralization. y = x - mean(x, axis).

    Note:
        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
        - **axis** (Union[Int, Tuple(Int), List(Int)]) - The dimensions to reduce. Default: (),
          reduce all dimensions. Only constant value is allowed.
          Must be in the range [-rank(input_x), rank(input_x)).

    Outputs:
        Tensor, has the same shape and dtype as the `input_x`.

    Raises:
        TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType.
        TypeError: If `axis` has non-Int elements.
        ValueError: If `axis` is not a constant, or any axis value is out of
            [-rank(input_x), rank(input_x)).

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> mindspore.set_seed(1)
        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
        >>> centralization = ops.Centralization()
        >>> output = centralization(input_x, -1)
        >>> print(output)
        [[ 1.1180509 -1.1180508]
         [ 0.2723984 -0.2723984]]
    """

    __mindspore_signature__ = (
        sig.make_sig('input_x'),
        sig.make_sig('axis', default=())
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Centralization"""
        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])

    def __infer__(self, input_x, axis):
        """Infer output shape/dtype: validates dtype and the constant `axis`,
        then propagates the input shape and dtype unchanged."""
        x_shape = list(input_x['shape'])
        x_dtype = input_x['dtype']
        axis_v = axis['value']
        rank = len(x_shape)

        args = {'input_x': input_x['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        # `axis` must be a compile-time constant: its value is needed here at infer time.
        if axis_v is None:
            raise ValueError(f"For {self.name}, axis must be const.")
        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)

        if isinstance(axis_v, int):
            validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
        elif axis_v:
            # Bug fix: the original tested `axis` (the arg-info dict, which is
            # always truthy) instead of the constant value `axis_v`.
            # Also check each element's range, mirroring the scalar branch and
            # the documented [-rank, rank) contract.
            for index, one_axis in enumerate(axis_v):
                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)
                validator.check_int_range(one_axis, -rank, rank, Rel.INC_LEFT,
                                          'axis[%d]' % index, self.name)

        out = {'shape': x_shape,
               'dtype': x_dtype,
               'value': None}
        return out
| @@ -21,7 +21,6 @@ from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common.dtype import tensor, dtype_to_pytype | |||
| from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer | |||
| from .. import signature as sig | |||
| class ScalarCast(PrimitiveWithInfer): | |||
| @@ -358,70 +357,3 @@ class MakeRefKey(Primitive): | |||
| def __call__(self): | |||
| pass | |||
class Centralization(PrimitiveWithInfer):
    """
    Computes centralization. y = x - mean(x, axis).

    Note:
        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
        - **axis** (Union[Int, Tuple(Int), List(Int)]) - The dimensions to reduce. Default: (),
          reduce all dimensions. Only constant value is allowed.
          Must be in the range [-rank(input_x), rank(input_x)).

    Outputs:
        Tensor, has the same shape and dtype as the `input_x`.

    Raises:
        TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType.
        TypeError: If `axis` has non-Int elements.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> mindspore.set_seed(1)
        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
        >>> centralization = ops.Centralization()
        >>> output = centralization(input_x, -1)
        >>> print(output)
        [[ 1.1180509 -1.1180508]
         [ 0.2723984 -0.2723984]]
    """

    __mindspore_signature__ = (
        sig.make_sig('input_x'),
        sig.make_sig('axis', default=())
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Centralization"""
        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])

    def __infer__(self, input_x, axis):
        """Validate inputs at compile time and propagate shape/dtype unchanged."""
        out_shape = list(input_x['shape'])
        out_dtype = input_x['dtype']
        validator.check_tensors_dtypes_same_and_valid(
            {'input_x': input_x['dtype']}, [mstype.float16, mstype.float32], self.name)

        reduce_axes = axis['value']
        # The axis value must be known here; a non-constant axis has value None.
        if reduce_axes is None:
            raise ValueError(f"For {self.name}, axis must be const.")
        validator.check_value_type('axis', reduce_axes, [int, list, tuple], self.name)

        dims = len(out_shape)
        if isinstance(reduce_axes, int):
            validator.check_int_range(reduce_axes, -dims, dims, Rel.INC_LEFT, 'axis', self.name)
        else:
            # list/tuple case: each entry must itself be an int.
            for index, one_axis in enumerate(reduce_axes):
                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)

        return {'shape': out_shape, 'dtype': out_dtype, 'value': None}
| @@ -18,12 +18,12 @@ import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| class Net(nn.Cell): | |||
| def __init__(self, axis=()): | |||
| super(Net, self).__init__() | |||
| self.centralization = P.Centralization() | |||
| self.centralization = inner.Centralization() | |||
| self.axis = axis | |||
| @ms_function | |||