Merge pull request !1450 from zhaozhenlong/op/scatter-add-vmtags/v0.5.0-beta
| @@ -198,3 +198,4 @@ from .apply_rms_prop import _apply_rms_prop_tbe | |||||
| from .cumprod import _cumprop_tbe | from .cumprod import _cumprop_tbe | ||||
| from .reduce_prod import _reduce_prod_tbe | from .reduce_prod import _reduce_prod_tbe | ||||
| from .flatten_grad import _flatten_grad_tbe | from .flatten_grad import _flatten_grad_tbe | ||||
| from .scatter_add import _scatter_add_tbe | |||||
| @@ -0,0 +1,40 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ScatterAdd op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| scatter_add_op_info = TBERegOp("ScatterAdd") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("scatter_add.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("scatter_add") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("use_locking", "optional", "bool", "all") \ | |||||
| .input(0, "var", False, "required", "all") \ | |||||
| .input(1, "indices", False, "required", "all") \ | |||||
| .input(2, "updates", False, "required", "all") \ | |||||
| .output(0, "var", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(scatter_add_op_info) | |||||
| def _scatter_add_tbe(): | |||||
| """ScatterAdd TBE register""" | |||||
| return | |||||
| @@ -24,7 +24,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||||
| Fill, GatherNd, GatherV2, InvertPermutation, | Fill, GatherNd, GatherV2, InvertPermutation, | ||||
| IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | ||||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, | Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, | ||||
| SameTypeShape, ScatterMax, ScatterUpdate, | |||||
| SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, | |||||
| ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | ||||
| Shape, Size, Slice, Split, | Shape, Size, Slice, Split, | ||||
| Squeeze, StridedSlice, Tile, | Squeeze, StridedSlice, Tile, | ||||
| @@ -190,6 +190,7 @@ __all__ = [ | |||||
| 'BoundingBoxEncode', | 'BoundingBoxEncode', | ||||
| 'BoundingBoxDecode', | 'BoundingBoxDecode', | ||||
| 'L2Normalize', | 'L2Normalize', | ||||
| 'ScatterAdd', | |||||
| 'ScatterNd', | 'ScatterNd', | ||||
| 'ScatterMax', | 'ScatterMax', | ||||
| 'ResizeNearestNeighbor', | 'ResizeNearestNeighbor', | ||||
| @@ -2145,6 +2145,12 @@ class ScatterNdUpdate(PrimitiveWithInfer): | |||||
| validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) | ||||
| return x_dtype | return x_dtype | ||||
| def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name): | |||||
| if updates_shape and updates_shape != indices_shape + x_shape[1:]: | |||||
| raise ValueError(f"For '{prim_name}', the shape of updates should be [] or " | |||||
| f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " | |||||
| f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.") | |||||
| class ScatterMax(PrimitiveWithInfer): | class ScatterMax(PrimitiveWithInfer): | ||||
| """ | """ | ||||
| @@ -2158,8 +2164,8 @@ class ScatterMax(PrimitiveWithInfer): | |||||
| Inputs: | Inputs: | ||||
| - **input_x** (Parameter) - The target parameter. | - **input_x** (Parameter) - The target parameter. | ||||
| - **indices** (Tensor) - The index to do max operation whose data type should be int. | - **indices** (Tensor) - The index to do max operation whose data type should be int. | ||||
| - **updates** (Tensor) - The tensor doing the maximum operation with 'input_x', | |||||
| the data type is same as 'input_x', the shape is 'indices_shape + x_shape[1:]'. | |||||
| - **updates** (Tensor) - The tensor doing the maximum operation with `input_x`, | |||||
| the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. | |||||
| Outputs: | Outputs: | ||||
| Tensor, has the same shape and data type as `input_x`. | Tensor, has the same shape and data type as `input_x`. | ||||
| @@ -2180,10 +2186,7 @@ class ScatterMax(PrimitiveWithInfer): | |||||
| validator.check_value_type('use_locking', use_locking, (bool,), self.name) | validator.check_value_type('use_locking', use_locking, (bool,), self.name) | ||||
| def infer_shape(self, x_shape, indices_shape, updates_shape): | def infer_shape(self, x_shape, indices_shape, updates_shape): | ||||
| if updates_shape and updates_shape != indices_shape + x_shape[1:]: | |||||
| raise ValueError(f"For '{self.name}', the shape of update should be [] or " | |||||
| f"update_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, " | |||||
| f"indices_shape: {indices_shape}, update_shape: {updates_shape}.") | |||||
| _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): | def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): | ||||
| @@ -2193,6 +2196,49 @@ class ScatterMax(PrimitiveWithInfer): | |||||
| return x_dtype | return x_dtype | ||||
| class ScatterAdd(PrimitiveWithInfer): | |||||
| """ | |||||
| Update the value of the input tensor through the add operation. | |||||
| Using given values to update tensor value through the add operation, along with the input indices. | |||||
| Args: | |||||
| use_locking (bool): Whether protect the assignment by a lock. Default: False. | |||||
| Inputs: | |||||
| - **input_x** (Parameter) - The target parameter. | |||||
| - **indices** (Tensor) - The index to do add operation whose data type should be int. | |||||
| - **updates** (Tensor) - The tensor doing the add operation with `input_x`, | |||||
| the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. | |||||
| Outputs: | |||||
| Tensor, has the same shape and data type as `input_x`. | |||||
| Examples: | |||||
| >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x") | |||||
| >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32) | |||||
| >>> updates = Tensor(np.ones([2, 2, 3]), mindspore.float32) | |||||
| >>> scatter_add = P.ScatterAdd() | |||||
| >>> output = scatter_add(input_x, indices, updates) | |||||
| [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]] | |||||
| """ | |||||
| @prim_attr_register | |||||
| def __init__(self, use_locking=False): | |||||
| """Init ScatterAdd""" | |||||
| validator.check_value_type('use_locking', use_locking, (bool,), self.name) | |||||
| def infer_shape(self, x_shape, indices_shape, updates_shape): | |||||
| _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) | |||||
| return x_shape | |||||
| def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): | |||||
| validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) | |||||
| args = {'x': x_dtype, 'updates': updates_dtype} | |||||
| validator.check_tensor_type_same(args, mstype.number_type, self.name) | |||||
| return x_dtype | |||||
| class SpaceToDepth(PrimitiveWithInfer): | class SpaceToDepth(PrimitiveWithInfer): | ||||
| r""" | r""" | ||||
| Rearrange blocks of spatial data into depth. | Rearrange blocks of spatial data into depth. | ||||
| @@ -196,6 +196,19 @@ class ScatterMax(nn.Cell): | |||||
| return out | return out | ||||
| class ScatterAdd(nn.Cell): | |||||
| """ScatterAdd net definition""" | |||||
| def __init__(self, ref_shape): | |||||
| super(ScatterAdd, self).__init__() | |||||
| self.scatter_add = P.ScatterAdd() | |||||
| self.ref = Parameter(Tensor(np.ones(ref_shape, np.float32)), name="ref") | |||||
| def construct(self, indices, updates): | |||||
| out = self.scatter_add(self.ref, indices, updates) | |||||
| return out | |||||
| class ApplyFtrlNet(nn.Cell): | class ApplyFtrlNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(ApplyFtrlNet, self).__init__() | super(ApplyFtrlNet, self).__init__() | ||||
| @@ -1257,6 +1270,17 @@ test_case_other_ops = [ | |||||
| 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), | 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), | ||||
| Tensor(np.ones([2, 2, 3], np.float32) * 99)), | Tensor(np.ones([2, 2, 3], np.float32) * 99)), | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('ScatterAdd', { | |||||
| 'block': ScatterAdd((6,)), | |||||
| 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), | |||||
| Tensor(np.array([2.0, 3.0, 4.0], np.float32))), | |||||
| 'skip': ['backward']}), | |||||
| ('ScatterAdd2d', { | |||||
| 'block': ScatterAdd((3, 4)), | |||||
| 'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)), | |||||
| Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]], | |||||
| [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))), | |||||
| 'skip': ['backward']}), | |||||
| ('SmoothL1Loss', { | ('SmoothL1Loss', { | ||||
| 'block': P.SmoothL1Loss(), | 'block': P.SmoothL1Loss(), | ||||
| 'desc_inputs': [[256, 4], [256, 4]], | 'desc_inputs': [[256, 4], [256, 4]], | ||||