From: @david-he91 Reviewed-by: @liangchenghui,@guoqi1024 Signed-off-by: @liangchenghuitags/v1.2.0-rc1
| @@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, | |||||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | ||||
| TensorSummary, HistogramSummary, Print, Assert) | TensorSummary, HistogramSummary, Print, Assert) | ||||
| from .control_ops import ControlDepend, GeSwitch, Merge | from .control_ops import ControlDepend, GeSwitch, Merge | ||||
| from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey, Centralization | |||||
| from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey | |||||
| from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, | from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, | ||||
| BitwiseAnd, BitwiseOr, | BitwiseAnd, BitwiseOr, | ||||
| @@ -22,7 +22,7 @@ from ...common import dtype as mstype | |||||
| from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register | from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register | ||||
| from ..operations.math_ops import _infer_shape_reduce | from ..operations.math_ops import _infer_shape_reduce | ||||
| from ...communication.management import GlobalComm | from ...communication.management import GlobalComm | ||||
| from .. import signature as sig | |||||
| class ExtractImagePatches(PrimitiveWithInfer): | class ExtractImagePatches(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -815,3 +815,70 @@ class SyncBatchNorm(PrimitiveWithInfer): | |||||
| args_moving = {"mean": mean, "variance": variance} | args_moving = {"mean": mean, "variance": variance} | ||||
| validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name) | validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name) | ||||
| return (input_x, scale, bias, input_x, input_x) | return (input_x, scale, bias, input_x, input_x) | ||||
| class Centralization(PrimitiveWithInfer): | |||||
| """ | |||||
| Computes centralization. y = x - mean(x, axis). | |||||
| Note: | |||||
| The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`. | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - The input tensor. The data type mast be float16 or float32. | |||||
| - **axis** (Union[Int, Tuple(Int), List(Int)]) - The dimensions to reduce. Default: (), reduce all dimensions. | |||||
| Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)). | |||||
| Outputs: | |||||
| Tensor, has the same shape and dtype as the `input_x`. | |||||
| Raises: | |||||
| TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType. | |||||
| TypeError: If `axis` has non-Int elements. | |||||
| Supported Platforms: | |||||
| ``Ascend`` | |||||
| Examples: | |||||
| >>> mindspore.set_seed(1) | |||||
| >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32)) | |||||
| >>> centralization = ops.Centralization() | |||||
| >>> output = centralization(input_x, -1) | |||||
| >>> print(output) | |||||
| [[ 1.1180509 -1.1180508] | |||||
| [ 0.2723984 -0.2723984]] | |||||
| """ | |||||
| __mindspore_signature__ = ( | |||||
| sig.make_sig('input_x'), | |||||
| sig.make_sig('axis', default=()) | |||||
| ) | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """Initialize Centralization""" | |||||
| self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output']) | |||||
| def __infer__(self, input_x, axis): | |||||
| x_shape = list(input_x['shape']) | |||||
| x_dtype = input_x['dtype'] | |||||
| axis_v = axis['value'] | |||||
| rank = len(x_shape) | |||||
| args = {'input_x': input_x['dtype']} | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||||
| if axis_v is None: | |||||
| raise ValueError(f"For {self.name}, axis must be const.") | |||||
| validator.check_value_type('axis', axis_v, [int, list, tuple], self.name) | |||||
| if isinstance(axis_v, int): | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, 'axis', self.name) | |||||
| elif axis: | |||||
| for index, one_axis in enumerate(axis_v): | |||||
| validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name) | |||||
| out = {'shape': x_shape, | |||||
| 'dtype': x_dtype, | |||||
| 'value': None} | |||||
| return out | |||||
| @@ -21,7 +21,6 @@ from ..._checkparam import Rel | |||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.dtype import tensor, dtype_to_pytype | from ...common.dtype import tensor, dtype_to_pytype | ||||
| from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer | from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer | ||||
| from .. import signature as sig | |||||
| class ScalarCast(PrimitiveWithInfer): | class ScalarCast(PrimitiveWithInfer): | ||||
| @@ -358,70 +357,3 @@ class MakeRefKey(Primitive): | |||||
| def __call__(self): | def __call__(self): | ||||
| pass | pass | ||||
| class Centralization(PrimitiveWithInfer): | |||||
| """ | |||||
| Computes centralization. y = x - mean(x, axis). | |||||
| Note: | |||||
| The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`. | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - The input tensor. The data type mast be float16 or float32. | |||||
| - **axis** (Union[Int, Tuple(Int), List(Int)]) - The dimensions to reduce. Default: (), reduce all dimensions. | |||||
| Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)). | |||||
| Outputs: | |||||
| Tensor, has the same shape and dtype as the `input_x`. | |||||
| Raises: | |||||
| TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType. | |||||
| TypeError: If `axis` has non-Int elements. | |||||
| Supported Platforms: | |||||
| ``Ascend`` | |||||
| Examples: | |||||
| >>> mindspore.set_seed(1) | |||||
| >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32)) | |||||
| >>> centralization = ops.Centralization() | |||||
| >>> output = centralization(input_x, -1) | |||||
| >>> print(output) | |||||
| [[ 1.1180509 -1.1180508] | |||||
| [ 0.2723984 -0.2723984]] | |||||
| """ | |||||
| __mindspore_signature__ = ( | |||||
| sig.make_sig('input_x'), | |||||
| sig.make_sig('axis', default=()) | |||||
| ) | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """Initialize Centralization""" | |||||
| self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output']) | |||||
| def __infer__(self, input_x, axis): | |||||
| x_shape = list(input_x['shape']) | |||||
| x_dtype = input_x['dtype'] | |||||
| axis_v = axis['value'] | |||||
| rank = len(x_shape) | |||||
| args = {'input_x': input_x['dtype']} | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||||
| if axis_v is None: | |||||
| raise ValueError(f"For {self.name}, axis must be const.") | |||||
| validator.check_value_type('axis', axis_v, [int, list, tuple], self.name) | |||||
| if isinstance(axis_v, int): | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, 'axis', self.name) | |||||
| elif axis: | |||||
| for index, one_axis in enumerate(axis_v): | |||||
| validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name) | |||||
| out = {'shape': x_shape, | |||||
| 'dtype': x_dtype, | |||||
| 'value': None} | |||||
| return out | |||||
| @@ -18,12 +18,12 @@ import mindspore.context as context | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.common.api import ms_function | from mindspore.common.api import ms_function | ||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops.operations import _inner_ops as inner | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| def __init__(self, axis=()): | def __init__(self, axis=()): | ||||
| super(Net, self).__init__() | super(Net, self).__init__() | ||||
| self.centralization = P.Centralization() | |||||
| self.centralization = inner.Centralization() | |||||
| self.axis = axis | self.axis = axis | ||||
| @ms_function | @ms_function | ||||