From: @david-he91 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -38,6 +38,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { | |||||
| Register(prim::kPrimReduceMin->name(), {1}); | Register(prim::kPrimReduceMin->name(), {1}); | ||||
| Register(prim::kPrimReduceSum->name(), {1}); | Register(prim::kPrimReduceSum->name(), {1}); | ||||
| Register(prim::kPrimReduceMean->name(), {1}); | Register(prim::kPrimReduceMean->name(), {1}); | ||||
| Register(prim::kPrimCentralization->name(), {1}); | |||||
| Register(prim::kPrimGather->name(), {2}); | Register(prim::kPrimGather->name(), {2}); | ||||
| Register(prim::kPrimGatherD->name(), {1}); | Register(prim::kPrimGatherD->name(), {1}); | ||||
| Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5}); | Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5}); | ||||
| @@ -350,6 +350,7 @@ inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAl | |||||
| inline const PrimitivePtr kPrimReduceAny = std::make_shared<Primitive>("ReduceAny"); | inline const PrimitivePtr kPrimReduceAny = std::make_shared<Primitive>("ReduceAny"); | ||||
| inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax"); | inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax"); | ||||
| inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin"); | inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin"); | ||||
| inline const PrimitivePtr kPrimCentralization = std::make_shared<Primitive>("Centralization"); | |||||
| inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg"); | inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg"); | ||||
| inline const PrimitivePtr kPrimSin = std::make_shared<Primitive>("Sin"); | inline const PrimitivePtr kPrimSin = std::make_shared<Primitive>("Sin"); | ||||
| inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>("Cos"); | inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>("Cos"); | ||||
| @@ -360,3 +360,4 @@ from .nll_loss_grad import _nll_loss_grad_tbe | |||||
| from .mish import _mish_tbe | from .mish import _mish_tbe | ||||
| from .mul_no_nan import _mul_no_nan_tbe | from .mul_no_nan import _mul_no_nan_tbe | ||||
| from .selu import _selu_tbe | from .selu import _selu_tbe | ||||
| from .centralization import _centralization_tbe | |||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Centralization op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| centralization_op_info = TBERegOp("Centralization") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("centralization.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("centralization") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("axis", "required", "listInt", "all") \ | |||||
| .input(0, "x", False, "required", "all") \ | |||||
| .output(0, "y", False, "required", "all") \ | |||||
| .op_pattern("reduce") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(centralization_op_info) | |||||
| def _centralization_tbe(): | |||||
| """Centralization TBE register""" | |||||
| return | |||||
| @@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, | |||||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | ||||
| TensorSummary, HistogramSummary, Print, Assert) | TensorSummary, HistogramSummary, Print, Assert) | ||||
| from .control_ops import ControlDepend, GeSwitch, Merge | from .control_ops import ControlDepend, GeSwitch, Merge | ||||
| from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey | |||||
| from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey, Centralization | |||||
| from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, | from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, | ||||
| BitwiseAnd, BitwiseOr, | BitwiseAnd, BitwiseOr, | ||||
| @@ -21,6 +21,7 @@ from ..._checkparam import Rel | |||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.dtype import tensor, dtype_to_pytype | from ...common.dtype import tensor, dtype_to_pytype | ||||
| from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer | from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer | ||||
| from .. import signature as sig | |||||
| class ScalarCast(PrimitiveWithInfer): | class ScalarCast(PrimitiveWithInfer): | ||||
| @@ -357,3 +358,70 @@ class MakeRefKey(Primitive): | |||||
| def __call__(self): | def __call__(self): | ||||
| pass | pass | ||||
| class Centralization(PrimitiveWithInfer): | |||||
| """ | |||||
| Computes centralization. y = x - mean(x, axis). | |||||
| Note: | |||||
| The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`. | |||||
| Inputs: | |||||
| - **input_x** (Tensor) - The input tensor. The data type mast be float16 or float32. | |||||
| - **axis** (Union[Int, Tuple(Int), List(Int)]) - The dimensions to reduce. Default: (), reduce all dimensions. | |||||
| Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)). | |||||
| Outputs: | |||||
| Tensor, has the same shape and dtype as the `input_x`. | |||||
| Raises: | |||||
| TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType. | |||||
| TypeError: If `axis` has non-Int elements. | |||||
| Supported Platforms: | |||||
| ``Ascend`` | |||||
| Examples: | |||||
| >>> mindspore.set_seed(1) | |||||
| >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32)) | |||||
| >>> centralization = ops.Centralization() | |||||
| >>> output = centralization(input_x, -1) | |||||
| >>> print(output) | |||||
| [[ 1.1180509 -1.1180508] | |||||
| [ 0.2723984 -0.2723984]] | |||||
| """ | |||||
| __mindspore_signature__ = ( | |||||
| sig.make_sig('input_x'), | |||||
| sig.make_sig('axis', default=()) | |||||
| ) | |||||
| @prim_attr_register | |||||
| def __init__(self): | |||||
| """Initialize Centralization""" | |||||
| self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output']) | |||||
| def __infer__(self, input_x, axis): | |||||
| x_shape = list(input_x['shape']) | |||||
| x_dtype = input_x['dtype'] | |||||
| axis_v = axis['value'] | |||||
| rank = len(x_shape) | |||||
| args = {'input_x': input_x['dtype']} | |||||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||||
| if axis_v is None: | |||||
| raise ValueError(f"For {self.name}, axis must be const.") | |||||
| validator.check_value_type('axis', axis_v, [int, list, tuple], self.name) | |||||
| if isinstance(axis_v, int): | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, 'axis', self.name) | |||||
| elif axis: | |||||
| for index, one_axis in enumerate(axis_v): | |||||
| validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name) | |||||
| out = {'shape': x_shape, | |||||
| 'dtype': x_dtype, | |||||
| 'value': None} | |||||
| return out | |||||
| @@ -0,0 +1,47 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.common.api import ms_function | |||||
| from mindspore.ops import operations as P | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, axis=()): | |||||
| super(Net, self).__init__() | |||||
| self.centralization = P.Centralization() | |||||
| self.axis = axis | |||||
| @ms_function | |||||
| def construct(self, inputs): | |||||
| return self.centralization(inputs, self.axis) | |||||
| def test_net(): | |||||
| np.random.seed(1) | |||||
| x1 = np.random.randn(2, 2).astype(np.float32) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| centralization = Net(-1) | |||||
| output = centralization(Tensor(x1)) | |||||
| print(x1) | |||||
| print(output.asnumpy()) | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| centralization = Net(-1) | |||||
| output = centralization(Tensor(x1)) | |||||
| print(x1) | |||||
| print(output.asnumpy()) | |||||