| @@ -38,6 +38,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { | |||
| Register(prim::kPrimReduceMin->name(), {1}); | |||
| Register(prim::kPrimReduceSum->name(), {1}); | |||
| Register(prim::kPrimReduceMean->name(), {1}); | |||
| Register(prim::kPrimCentralization->name(), {1}); | |||
| Register(prim::kPrimGather->name(), {2}); | |||
| Register(prim::kPrimGatherD->name(), {1}); | |||
| Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5}); | |||
| @@ -350,6 +350,7 @@ inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAl | |||
| inline const PrimitivePtr kPrimReduceAny = std::make_shared<Primitive>("ReduceAny"); | |||
| inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax"); | |||
| inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin"); | |||
| inline const PrimitivePtr kPrimCentralization = std::make_shared<Primitive>("Centralization"); | |||
| inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg"); | |||
| inline const PrimitivePtr kPrimSin = std::make_shared<Primitive>("Sin"); | |||
| inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>("Cos"); | |||
| @@ -360,3 +360,4 @@ from .nll_loss_grad import _nll_loss_grad_tbe | |||
| from .mish import _mish_tbe | |||
| from .mul_no_nan import _mul_no_nan_tbe | |||
| from .selu import _selu_tbe | |||
| from .centralization import _centralization_tbe | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Centralization op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| centralization_op_info = TBERegOp("Centralization") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("centralization.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("centralization") \ | |||
| .partial_flag(True) \ | |||
| .attr("axis", "required", "listInt", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .op_pattern("reduce") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(centralization_op_info) | |||
| def _centralization_tbe(): | |||
| """Centralization TBE register""" | |||
| return | |||
| @@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | |||
| TensorSummary, HistogramSummary, Print, Assert) | |||
| from .control_ops import ControlDepend, GeSwitch, Merge | |||
| from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey | |||
| from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey, Centralization | |||
| from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, | |||
| BitwiseAnd, BitwiseOr, | |||
| @@ -21,6 +21,7 @@ from ..._checkparam import Rel | |||
| from ...common import dtype as mstype | |||
| from ...common.dtype import tensor, dtype_to_pytype | |||
| from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer | |||
| from .. import signature as sig | |||
| class ScalarCast(PrimitiveWithInfer): | |||
| @@ -357,3 +358,70 @@ class MakeRefKey(Primitive): | |||
| def __call__(self): | |||
| pass | |||
| class Centralization(PrimitiveWithInfer): | |||
| """ | |||
| Computes centralization. y = x - mean(x, axis). | |||
| Note: | |||
| The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`. | |||
| Inputs: | |||
| - **input_x** (Tensor) - The input tensor. The data type mast be float16 or float32. | |||
| - **axis** (Union[Int, Tuple(Int), List(Int)]) - The dimensions to reduce. Default: (), reduce all dimensions. | |||
| Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)). | |||
| Outputs: | |||
| Tensor, has the same shape and dtype as the `input_x`. | |||
| Raises: | |||
| TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType. | |||
| TypeError: If `axis` has non-Int elements. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| Examples: | |||
| >>> mindspore.set_seed(1) | |||
| >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32)) | |||
| >>> centralization = ops.Centralization() | |||
| >>> output = centralization(input_x, -1) | |||
| >>> print(output) | |||
| [[ 1.1180509 -1.1180508] | |||
| [ 0.2723984 -0.2723984]] | |||
| """ | |||
| __mindspore_signature__ = ( | |||
| sig.make_sig('input_x'), | |||
| sig.make_sig('axis', default=()) | |||
| ) | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize Centralization""" | |||
| self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output']) | |||
| def __infer__(self, input_x, axis): | |||
| x_shape = list(input_x['shape']) | |||
| x_dtype = input_x['dtype'] | |||
| axis_v = axis['value'] | |||
| rank = len(x_shape) | |||
| args = {'input_x': input_x['dtype']} | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||
| if axis_v is None: | |||
| raise ValueError(f"For {self.name}, axis must be const.") | |||
| validator.check_value_type('axis', axis_v, [int, list, tuple], self.name) | |||
| if isinstance(axis_v, int): | |||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, 'axis', self.name) | |||
| elif axis: | |||
| for index, one_axis in enumerate(axis_v): | |||
| validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name) | |||
| out = {'shape': x_shape, | |||
| 'dtype': x_dtype, | |||
| 'value': None} | |||
| return out | |||
| @@ -0,0 +1,47 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops import operations as P | |||
| class Net(nn.Cell): | |||
| def __init__(self, axis=()): | |||
| super(Net, self).__init__() | |||
| self.centralization = P.Centralization() | |||
| self.axis = axis | |||
| @ms_function | |||
| def construct(self, inputs): | |||
| return self.centralization(inputs, self.axis) | |||
| def test_net(): | |||
| np.random.seed(1) | |||
| x1 = np.random.randn(2, 2).astype(np.float32) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| centralization = Net(-1) | |||
| output = centralization(Tensor(x1)) | |||
| print(x1) | |||
| print(output.asnumpy()) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||
| centralization = Net(-1) | |||
| output = centralization(Tensor(x1)) | |||
| print(x1) | |||
| print(output.asnumpy()) | |||