Merge pull request !1973 from zhangbuxue/develop_TensorScatterUpdate_op_and_access_ge_and_vmtags/v0.5.0-beta
| @@ -103,6 +103,7 @@ const char kNameReLU6[] = "ReLU6"; | |||
| const char kNameReLU6Grad[] = "ReLU6Grad"; | |||
| const char kNameElu[] = "Elu"; | |||
| const char kNameEluGrad[] = "EluGrad"; | |||
| const char kNameTensorScatterUpdate[] = "TensorScatterUpdate"; | |||
| const char kNameScatterUpdate[] = "ScatterUpdate"; | |||
| const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; | |||
| const char kNameScatterMax[] = "ScatterMax"; | |||
| @@ -261,6 +262,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||
| {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, | |||
| {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, | |||
| {string(kNameOnesLike), ADPT_DESC(OnesLike)}, | |||
| {string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)}, | |||
| {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)}, | |||
| {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, | |||
| {string(kNameScatterMax), ADPT_DESC(ScatterMax)}, | |||
| @@ -525,6 +525,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}}; | |||
| ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}}; | |||
| DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}}; | |||
| // TensorScatterUpdate | |||
| INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; | |||
| ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}}; | |||
| // ScatterUpdate | |||
| INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; | |||
| ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; | |||
| @@ -134,6 +134,8 @@ DECLARE_OP_ADAPTER(ZerosLike) | |||
| DECLARE_OP_USE_OUTPUT(ZerosLike) | |||
| DECLARE_OP_ADAPTER(OnesLike) | |||
| DECLARE_OP_USE_OUTPUT(OnesLike) | |||
| DECLARE_OP_ADAPTER(TensorScatterUpdate) | |||
| DECLARE_OP_USE_OUTPUT(TensorScatterUpdate) | |||
| DECLARE_OP_ADAPTER(ScatterUpdate) | |||
| DECLARE_OP_USE_OUTPUT(ScatterUpdate) | |||
| DECLARE_OP_ADAPTER(ScatterNdUpdate) | |||
| @@ -456,6 +456,20 @@ def get_bprop_scatter_nd_update(self): | |||
| return bprop | |||
@bprop_getters.register(P.TensorScatterUpdate)
def get_bprop_tensor_scatter_update(self):
    """Build the backward function for TensorScatterUpdate."""
    scatter_op = P.TensorScatterUpdate()
    gather_op = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        # The gradient for `update` is dout read back at the scattered positions.
        grad_wrt_update = gather_op(dout, indices)
        # The gradient for `x` is dout with the scattered positions replaced by
        # zeros, since those entries of the output were overwritten by `update`.
        grad_wrt_x = scatter_op(dout, indices, zeros_like(update))
        # `indices` gets a zero gradient of matching shape.
        return grad_wrt_x, zeros_like(indices), grad_wrt_update

    return bprop
| @bprop_getters.register(P.Argmax) | |||
| def get_bprop_argmax(self): | |||
| """Generate bprop for Argmax""" | |||
| @@ -255,3 +255,4 @@ from .lamb_next_right import _lamb_next_right_tbe | |||
| from .sparse_gather_v2 import _sparse_gather_v2_tbe | |||
| from .data_format_dim_map import _data_format_dim_map_tbe | |||
| from .histogram_fixed_width import _histogram_fixed_width_tbe | |||
| from .tensor_scatter_update import _tensor_scatter_update_tbe | |||
| @@ -31,7 +31,7 @@ scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .get_op_info() | |||
| @@ -31,7 +31,7 @@ scatter_update_op_info = TBERegOp("ScatterUpdate") \ | |||
| .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \ | |||
| .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ | |||
| .get_op_info() | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """TensorScatterUpdate op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
# Registration info for the TBE "TensorScatterUpdate" kernel.
# Inputs are registered with distinct positional indices: x (0), indices (1),
# updates (2). Each dtype_format row carries four entries, one per registered
# input followed by the output.
tensor_scatter_update_op_info = TBERegOp("TensorScatterUpdate") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("tensor_scatter_update.so") \
    .compute_cost(10) \
    .kernel_name("tensor_scatter_update") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "indices", False, "required", "all") \
    .input(2, "updates", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .get_op_info()


@op_info_register(tensor_scatter_update_op_info)
def _tensor_scatter_update_tbe():
    """TensorScatterUpdate TBE register"""
    return
| @@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, | |||
| ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, | |||
| Shape, Size, Slice, Split, EmbeddingLookup, | |||
| Squeeze, StridedSlice, Tile, | |||
| Squeeze, StridedSlice, Tile, TensorScatterUpdate, | |||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, | |||
| UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | |||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo) | |||
| @@ -212,6 +212,7 @@ __all__ = [ | |||
| 'Pad', | |||
| 'MirrorPad', | |||
| 'GatherNd', | |||
| 'TensorScatterUpdate', | |||
| 'ScatterUpdate', | |||
| 'ScatterNdUpdate', | |||
| 'Floor', | |||
| @@ -2187,6 +2187,47 @@ class GatherNd(PrimitiveWithInfer): | |||
| return x_dtype | |||
class TensorScatterUpdate(PrimitiveWithInfer):
    """
    Update tensor value by using input indices and value.

    Using given values to update tensor value, along with the input indices.

    Inputs:
        - **input_x** (Tensor) - The target tensor.
        - **indices** (Tensor) - The index of input tensor whose data type is int32.
          Must have at least one dimension, and `indices.shape[-1]` must not exceed
          the number of dimensions of `input_x`.
        - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
          and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:].

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Examples:
        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
        >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
        >>> op = P.TensorScatterUpdate()
        >>> output = op(input_x, indices, update)
    """

    @prim_attr_register
    def __init__(self):
        """Init TensorScatterUpdate"""
        self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])

    def infer_shape(self, x_shape, indices_shape, value_shape):
        """
        Validate the input shapes and return the output shape.

        Raises:
            ValueError: If `indices` has no dimensions, or if `value_shape` differs from
                indices_shape[:-1] + x_shape[indices_shape[-1]:].
        """
        # A rank-0 `indices` has no last dimension to read the index depth from;
        # report it explicitly instead of failing with an IndexError below.
        if not indices_shape:
            raise ValueError("For 'TensorScatterUpdate', input indices must have at least one dimension.")
        # The index depth (last dimension of indices) cannot exceed the rank of x.
        validator.check('the dimension of x', len(x_shape),
                        'the dimension of indices', indices_shape[-1], Rel.GE)
        if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
            raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.")
        return x_shape

    def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
        """Check that indices is int32 and that x and value share a bool/number type; return x's dtype."""
        validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
        args = {"x": x_dtype, "value": value_dtype}
        validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
        return x_dtype
| class ScatterUpdate(PrimitiveWithInfer): | |||
| """ | |||
| Update tensor value by using input indices and value. | |||
| @@ -2227,7 +2268,7 @@ class ScatterUpdate(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape, indices_shape, value_shape): | |||
| if indices_shape + x_shape[1:] != value_shape: | |||
| raise ValueError('Input value are not match with input indices.') | |||
| raise ValueError("For 'ScatterUpdate', input value are not match with input indices.") | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | |||
| @@ -2277,7 +2318,7 @@ class ScatterNdUpdate(PrimitiveWithInfer): | |||
| validator.check('the dimension of x', len(x_shape), | |||
| 'the dimension of indices', indices_shape[-1], Rel.GE) | |||
| if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape: | |||
| raise ValueError('Input value are not match with input indices.') | |||
| raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.") | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype, indices_dtype, value_dtype): | |||
| @@ -34,6 +34,25 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \ | |||
| import pipeline_for_compile_grad_ge_graph_for_case_by_case_config | |||
def test_tensor_scatter_update():
    """Run a small TensorScatterUpdate cell once in graph mode."""
    class TensorScatterUpdateNet(nn.Cell):
        """TensorScatterUpdate net definition"""

        def __init__(self):
            super(TensorScatterUpdateNet, self).__init__()
            self.tensor_scatter_update = P.TensorScatterUpdate()

        def construct(self, x, i, u):
            return self.tensor_scatter_update(x, i, u)

    net = TensorScatterUpdateNet()
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)

    # x is (3, 4, 5); each index row has depth 2, so every update row has length 5.
    input_x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
    index = Tensor(np.array([[0, 0], [1, 1]], np.int32))
    update = Tensor(np.ones([2, 5], np.float32))
    net(input_x, index, update)
| class InputBackward(nn.Cell): | |||
| def __init__(self, network): | |||
| super(InputBackward, self).__init__() | |||
| @@ -1537,6 +1556,12 @@ test_case_other_ops = [ | |||
| 'desc_inputs': (Tensor(np.ones((2, 2), np.int32)), | |||
| Tensor(np.ones((2,), np.int32))), | |||
| 'desc_bprop': [([3, 3], {'dtype': np.int32})]}), | |||
| ('TensorScatterUpdate', { | |||
| 'block': P.TensorScatterUpdate(), | |||
| 'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32), | |||
| Tensor(np.array([[0, 1], [1, 2]], np.int32)), | |||
| Tensor(np.ones([2, 5], np.float32) * 99)), | |||
| 'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}), | |||
| ('ScatterMax', { | |||
| 'block': ScatterMax(), | |||
| 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), | |||