Browse Source

!1973 develop TensorScatterUpdate op and access ge and vm

Merge pull request !1973 from zhangbuxue/develop_TensorScatterUpdate_op_and_access_ge_and_vm
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
b285c8fa9d
11 changed files with 137 additions and 5 deletions
  1. +2
    -0
      mindspore/ccsrc/transform/convert.cc
  2. +5
    -0
      mindspore/ccsrc/transform/op_declare.cc
  3. +2
    -0
      mindspore/ccsrc/transform/op_declare.h
  4. +14
    -0
      mindspore/ops/_grad/grad_array_ops.py
  5. +1
    -0
      mindspore/ops/_op_impl/tbe/__init__.py
  6. +1
    -1
      mindspore/ops/_op_impl/tbe/scatter_nd_update.py
  7. +1
    -1
      mindspore/ops/_op_impl/tbe/scatter_update.py
  8. +41
    -0
      mindspore/ops/_op_impl/tbe/tensor_scatter_update.py
  9. +2
    -1
      mindspore/ops/operations/__init__.py
  10. +43
    -2
      mindspore/ops/operations/array_ops.py
  11. +25
    -0
      tests/ut/python/ops/test_ops.py

+ 2
- 0
mindspore/ccsrc/transform/convert.cc View File

@@ -103,6 +103,7 @@ const char kNameReLU6[] = "ReLU6";
const char kNameReLU6Grad[] = "ReLU6Grad"; const char kNameReLU6Grad[] = "ReLU6Grad";
const char kNameElu[] = "Elu"; const char kNameElu[] = "Elu";
const char kNameEluGrad[] = "EluGrad"; const char kNameEluGrad[] = "EluGrad";
const char kNameTensorScatterUpdate[] = "TensorScatterUpdate";
const char kNameScatterUpdate[] = "ScatterUpdate"; const char kNameScatterUpdate[] = "ScatterUpdate";
const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; const char kNameScatterNdUpdate[] = "ScatterNdUpdate";
const char kNameScatterMax[] = "ScatterMax"; const char kNameScatterMax[] = "ScatterMax";
@@ -261,6 +262,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)},
{string(kNameZerosLike), ADPT_DESC(ZerosLike)}, {string(kNameZerosLike), ADPT_DESC(ZerosLike)},
{string(kNameOnesLike), ADPT_DESC(OnesLike)}, {string(kNameOnesLike), ADPT_DESC(OnesLike)},
{string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)},
{string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)}, {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)},
{string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)},
{string(kNameScatterMax), ADPT_DESC(ScatterMax)}, {string(kNameScatterMax), ADPT_DESC(ScatterMax)},


+ 5
- 0
mindspore/ccsrc/transform/op_declare.cc View File

@@ -525,6 +525,11 @@ INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}}; ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int>())}, {"num", ATTR_DESC(num, AnyTraits<int>())}};
DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}}; DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};


// TensorScatterUpdate
INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP;
OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}};

// ScatterUpdate // ScatterUpdate
INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}}; ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};


+ 2
- 0
mindspore/ccsrc/transform/op_declare.h View File

@@ -134,6 +134,8 @@ DECLARE_OP_ADAPTER(ZerosLike)
DECLARE_OP_USE_OUTPUT(ZerosLike) DECLARE_OP_USE_OUTPUT(ZerosLike)
DECLARE_OP_ADAPTER(OnesLike) DECLARE_OP_ADAPTER(OnesLike)
DECLARE_OP_USE_OUTPUT(OnesLike) DECLARE_OP_USE_OUTPUT(OnesLike)
DECLARE_OP_ADAPTER(TensorScatterUpdate)
DECLARE_OP_USE_OUTPUT(TensorScatterUpdate)
DECLARE_OP_ADAPTER(ScatterUpdate) DECLARE_OP_ADAPTER(ScatterUpdate)
DECLARE_OP_USE_OUTPUT(ScatterUpdate) DECLARE_OP_USE_OUTPUT(ScatterUpdate)
DECLARE_OP_ADAPTER(ScatterNdUpdate) DECLARE_OP_ADAPTER(ScatterNdUpdate)


+ 14
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -456,6 +456,20 @@ def get_bprop_scatter_nd_update(self):
return bprop return bprop




@bprop_getters.register(P.TensorScatterUpdate)
def get_bprop_tensor_scatter_update(self):
"""Generate bprop for TensorScatterUpdate"""
gather_nd = P.GatherNd()
tensor_scatter_update = P.TensorScatterUpdate()

def bprop(x, indices, update, out, dout):
x_grad = tensor_scatter_update(dout, indices, zeros_like(update))
update_grad = gather_nd(dout, indices)
return x_grad, zeros_like(indices), update_grad

return bprop


@bprop_getters.register(P.Argmax) @bprop_getters.register(P.Argmax)
def get_bprop_argmax(self): def get_bprop_argmax(self):
"""Generate bprop for Argmax""" """Generate bprop for Argmax"""


+ 1
- 0
mindspore/ops/_op_impl/tbe/__init__.py View File

@@ -255,3 +255,4 @@ from .lamb_next_right import _lamb_next_right_tbe
from .sparse_gather_v2 import _sparse_gather_v2_tbe from .sparse_gather_v2 import _sparse_gather_v2_tbe
from .data_format_dim_map import _data_format_dim_map_tbe from .data_format_dim_map import _data_format_dim_map_tbe
from .histogram_fixed_width import _histogram_fixed_width_tbe from .histogram_fixed_width import _histogram_fixed_width_tbe
from .tensor_scatter_update import _tensor_scatter_update_tbe

+ 1
- 1
mindspore/ops/_op_impl/tbe/scatter_nd_update.py View File

@@ -31,7 +31,7 @@ scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info() .get_op_info()




+ 1
- 1
mindspore/ops/_op_impl/tbe/scatter_update.py View File

@@ -31,7 +31,7 @@ scatter_update_op_info = TBERegOp("ScatterUpdate") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default,) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info() .get_op_info()




+ 41
- 0
mindspore/ops/_op_impl/tbe/tensor_scatter_update.py View File

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""TensorScatterUpdate op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

tensor_scatter_update_op_info = TBERegOp("TensorScatterUpdate") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("tensor_scatter_update.so") \
.compute_cost(10) \
.kernel_name("tensor_scatter_update") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(1, "updates", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()


@op_info_register(tensor_scatter_update_op_info)
def _tensor_scatter_update_tbe():
"""TensorScatterUpdate TBE register"""
return

+ 2
- 1
mindspore/ops/operations/__init__.py View File

@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, EmbeddingLookup, Shape, Size, Slice, Split, EmbeddingLookup,
Squeeze, StridedSlice, Tile,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo) SpaceToBatchND, BatchToSpaceND, BroadcastTo)
@@ -212,6 +212,7 @@ __all__ = [
'Pad', 'Pad',
'MirrorPad', 'MirrorPad',
'GatherNd', 'GatherNd',
'TensorScatterUpdate',
'ScatterUpdate', 'ScatterUpdate',
'ScatterNdUpdate', 'ScatterNdUpdate',
'Floor', 'Floor',


+ 43
- 2
mindspore/ops/operations/array_ops.py View File

@@ -2187,6 +2187,47 @@ class GatherNd(PrimitiveWithInfer):
return x_dtype return x_dtype




class TensorScatterUpdate(PrimitiveWithInfer):
"""
Update tensor value by using input indices and value.

Using given values to update tensor value, along with the input indices.

Inputs:
- **input_x** (Tensor) - The target tensor.
- **indices** (Tensor) - The index of input tensor whose data type is int32.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:].

Outputs:
Tensor, has the same shape and type as `input_x`.

Examples:
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.TensorScatterUpdate()
>>> output = op(input_x, indices, update)
"""
@prim_attr_register
def __init__(self):
"""Init TensorScatterUpdate"""
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])

def infer_shape(self, x_shape, indices_shape, value_shape):
validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE)
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.")
return x_shape

def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype


class ScatterUpdate(PrimitiveWithInfer): class ScatterUpdate(PrimitiveWithInfer):
""" """
Update tensor value by using input indices and value. Update tensor value by using input indices and value.
@@ -2227,7 +2268,7 @@ class ScatterUpdate(PrimitiveWithInfer):


def infer_shape(self, x_shape, indices_shape, value_shape): def infer_shape(self, x_shape, indices_shape, value_shape):
if indices_shape + x_shape[1:] != value_shape: if indices_shape + x_shape[1:] != value_shape:
raise ValueError('Input value are not match with input indices.')
raise ValueError("For 'ScatterUpdate', input value are not match with input indices.")
return x_shape return x_shape


def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
@@ -2277,7 +2318,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
validator.check('the dimension of x', len(x_shape), validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE) 'the dimension of indices', indices_shape[-1], Rel.GE)
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape: if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
raise ValueError('Input value are not match with input indices.')
raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.")
return x_shape return x_shape


def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):


+ 25
- 0
tests/ut/python/ops/test_ops.py View File

@@ -34,6 +34,25 @@ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config import pipeline_for_compile_grad_ge_graph_for_case_by_case_config




def test_tensor_scatter_update():
class TensorScatterUpdateNet(nn.Cell):
"""TensorScatterUpdate net definition"""

def __init__(self):
super(TensorScatterUpdateNet, self).__init__()
self.tensor_scatter_update = P.TensorScatterUpdate()

def construct(self, x, i, u):
out = self.tensor_scatter_update(x, i, u)
return out
net = TensorScatterUpdateNet()
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32)
indices = Tensor(np.array([[0, 0], [1, 1]], np.int32))
updates = Tensor(np.ones([2, 5], np.float32))
net(x, indices, updates)


class InputBackward(nn.Cell): class InputBackward(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(InputBackward, self).__init__() super(InputBackward, self).__init__()
@@ -1537,6 +1556,12 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.ones((2, 2), np.int32)), 'desc_inputs': (Tensor(np.ones((2, 2), np.int32)),
Tensor(np.ones((2,), np.int32))), Tensor(np.ones((2,), np.int32))),
'desc_bprop': [([3, 3], {'dtype': np.int32})]}), 'desc_bprop': [([3, 3], {'dtype': np.int32})]}),
('TensorScatterUpdate', {
'block': P.TensorScatterUpdate(),
'desc_inputs': (Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.ones([2, 5], np.float32) * 99)),
'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
('ScatterMax', { ('ScatterMax', {
'block': ScatterMax(), 'block': ScatterMax(),
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),


Loading…
Cancel
Save