Merge pull request !1973 from zhangbuxue/develop_TensorScatterUpdate_op_and_access_ge_and_vm
@@ -103,6 +103,7 @@ const char kNameReLU6[] = "ReLU6";
 const char kNameReLU6Grad[] = "ReLU6Grad";
 const char kNameElu[] = "Elu";
 const char kNameEluGrad[] = "EluGrad";
+const char kNameTensorScatterUpdate[] = "TensorScatterUpdate";
 const char kNameScatterUpdate[] = "ScatterUpdate";
 const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
 const char kNameScatterMax[] = "ScatterMax";
@@ -261,6 +262,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
     {string(kNameZerosLike), ADPT_DESC(ZerosLike)},
     {string(kNameOnesLike), ADPT_DESC(OnesLike)},
+    {string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)},
     {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
     {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
     {string(kNameScatterMax), ADPT_DESC(ScatterMax)},
@@ -525,6 +525,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
 DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};

+// TensorScatterUpdate
+INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
+ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}};
+
 // ScatterUpdate
 INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
 ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
@@ -134,6 +134,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
 DECLARE_OP_USE_OUTPUT(ZerosLike)
 DECLARE_OP_ADAPTER(OnesLike)
 DECLARE_OP_USE_OUTPUT(OnesLike)
+DECLARE_OP_ADAPTER(TensorScatterUpdate)
+DECLARE_OP_USE_OUTPUT(TensorScatterUpdate)
 DECLARE_OP_ADAPTER(ScatterUpdate)
 DECLARE_OP_USE_OUTPUT(ScatterUpdate)
 DECLARE_OP_ADAPTER(ScatterNdUpdate)
@@ -456,6 +456,20 @@ def get_bprop_scatter_nd_update(self):
     return bprop


+@bprop_getters.register(P.TensorScatterUpdate)
+def get_bprop_tensor_scatter_update(self):
+    """Generate bprop for TensorScatterUpdate"""
+    gather_nd = P.GatherNd()
+    tensor_scatter_update = P.TensorScatterUpdate()
+
+    def bprop(x, indices, update, out, dout):
+        x_grad = tensor_scatter_update(dout, indices, zeros_like(update))
+        update_grad = gather_nd(dout, indices)
+        return x_grad, zeros_like(indices), update_grad
+
+    return bprop
+
+
 @bprop_getters.register(P.Argmax)
 def get_bprop_argmax(self):
     """Generate bprop for Argmax"""
@@ -255,3 +255,4 @@ from .lamb_next_right import _lamb_next_right_tbe
 from .sparse_gather_v2 import _sparse_gather_v2_tbe
 from .data_format_dim_map import _data_format_dim_map_tbe
 from .histogram_fixed_width import _histogram_fixed_width_tbe
+from .tensor_scatter_update import _tensor_scatter_update_tbe
@@ -31,7 +31,7 @@ scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
     .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
-    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
+    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
     .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
     .get_op_info()
@@ -31,7 +31,7 @@ scatter_update_op_info = TBERegOp("ScatterUpdate") \
     .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
-    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
+    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
     .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
     .get_op_info()
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""TensorScatterUpdate op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+tensor_scatter_update_op_info = TBERegOp("TensorScatterUpdate") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("tensor_scatter_update.so") \
+    .compute_cost(10) \
+    .kernel_name("tensor_scatter_update") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .input(1, "indices", False, "required", "all") \
+    .input(2, "updates", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
+    .get_op_info()
+
+
+@op_info_register(tensor_scatter_update_op_info)
+def _tensor_scatter_update_tbe():
+    """TensorScatterUpdate TBE register"""
+    return
@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
                         ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
                         Shape, Size, Slice, Split, EmbeddingLookup,
-                        Squeeze, StridedSlice, Tile,
+                        Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
                         SpaceToBatchND, BatchToSpaceND, BroadcastTo)
@@ -212,6 +212,7 @@ __all__ = [
     'Pad',
     'MirrorPad',
     'GatherNd',
+    'TensorScatterUpdate',
     'ScatterUpdate',
     'ScatterNdUpdate',
     'Floor',
@@ -2187,6 +2187,47 @@ class GatherNd(PrimitiveWithInfer):
         return x_dtype


+class TensorScatterUpdate(PrimitiveWithInfer):
+    """
+    Updates tensor values by using input indices and values. The given values are
+    written into a copy of the input tensor at the positions specified by `indices`.
+
+    Inputs:
+        - **input_x** (Tensor) - The target tensor.
+        - **indices** (Tensor) - The index of the input tensor, with int32 data type.
+        - **update** (Tensor) - The tensor to update the input tensor, has the same type as the input,
+          and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
+
+    Outputs:
+        Tensor, has the same shape and type as `input_x`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
+        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
+        >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
+        >>> op = P.TensorScatterUpdate()
+        >>> output = op(input_x, indices, update)
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Init TensorScatterUpdate"""
+        self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
+
+    def infer_shape(self, x_shape, indices_shape, value_shape):
+        validator.check('the dimension of x', len(x_shape),
+                        'the dimension of indices', indices_shape[-1], Rel.GE)
+        if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
+            raise ValueError("For 'TensorScatterUpdate', the input value does not match the input indices.")
+        return x_shape
+
+    def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
+        validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
+        args = {"x": x_dtype, "value": value_dtype}
+        validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
+        return x_dtype
+
+
 class ScatterUpdate(PrimitiveWithInfer):
     """
     Update tensor value by using input indices and value.
@@ -2227,7 +2268,7 @@ class ScatterUpdate(PrimitiveWithInfer):
     def infer_shape(self, x_shape, indices_shape, value_shape):
         if indices_shape + x_shape[1:] != value_shape:
-            raise ValueError('Input value are not match with input indices.')
+            raise ValueError("For 'ScatterUpdate', the input value does not match the input indices.")
         return x_shape

     def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
@@ -2277,7 +2318,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
         validator.check('the dimension of x', len(x_shape),
                         'the dimension of indices', indices_shape[-1], Rel.GE)
         if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
-            raise ValueError('Input value are not match with input indices.')
+            raise ValueError("For 'ScatterNdUpdate', the input value does not match the input indices.")
         return x_shape

     def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
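For reviewers: the shape constraint enforced by `infer_shape` above is `update.shape == indices.shape[:-1] + input_x.shape[indices.shape[-1]:]`. The sketch below is a NumPy-only reference model of the intended forward semantics for illustration (not the TBE kernel); `tensor_scatter_update_ref` is a hypothetical helper name.

```python
import numpy as np

def tensor_scatter_update_ref(x, indices, updates):
    """NumPy reference: copy x, then write `updates` at the positions
    addressed by the last axis of `indices`; the input is left unchanged."""
    k = indices.shape[-1]
    assert updates.shape == indices.shape[:-1] + x.shape[k:]
    y = x.copy()
    flat_idx = indices.reshape(-1, k)                 # one row per scatter target
    flat_upd = updates.reshape((-1,) + x.shape[k:])   # matching update slices
    for idx, upd in zip(flat_idx, flat_upd):
        y[tuple(idx)] = upd
    return y

# Values taken from the docstring example above.
x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=np.float32)
indices = np.array([[0, 0], [1, 1]], dtype=np.int32)
update = np.array([1.0, 2.2], dtype=np.float32)
print(tensor_scatter_update_ref(x, indices, update))
# [[ 1.   0.3  3.6]
#  [ 0.4  2.2 -3.2]]
```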
@@ -34,6 +34,25 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
     import pipeline_for_compile_grad_ge_graph_for_case_by_case_config


+def test_tensor_scatter_update():
+    class TensorScatterUpdateNet(nn.Cell):
+        """TensorScatterUpdate net definition"""
+
+        def __init__(self):
+            super(TensorScatterUpdateNet, self).__init__()
+            self.tensor_scatter_update = P.TensorScatterUpdate()
+
+        def construct(self, x, i, u):
+            out = self.tensor_scatter_update(x, i, u)
+            return out
+
+    net = TensorScatterUpdateNet()
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
+    indices = Tensor(np.array([[0, 0], [1, 1]], np.int32))
+    updates = Tensor(np.ones([2, 5], np.float32))
+    net(x, indices, updates)
+
+
 class InputBackward(nn.Cell):
     def __init__(self, network):
         super(InputBackward, self).__init__()
@@ -1537,6 +1556,12 @@ test_case_other_ops = [
         'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
                         Tensor(np.ones((2,), np.int32))),
         'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
+    ('TensorScatterUpdate', {
+        'block': P.TensorScatterUpdate(),
+        'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
+                        Tensor(np.array([[0, 1], [1, 2]], np.int32)),
+                        Tensor(np.ones([2, 5], np.float32) * 99)),
+        'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
     ('ScatterMax', {
        'block': ScatterMax(),
        'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),