
wrap sequenceMask in composite

tags/v1.1.0
Peilin Wang, 5 years ago
commit 412665f707
6 changed files with 84 additions and 53 deletions:
  1. mindspore/ops/composite/__init__.py  +3 -2
  2. mindspore/ops/composite/array_ops.py  +33 -0
  3. mindspore/ops/operations/__init__.py  +1 -2
  4. mindspore/ops/operations/_inner_ops.py  +44 -0
  5. mindspore/ops/operations/array_ops.py  +0 -44
  6. tests/st/ops/gpu/test_sequence_mask_op.py  +3 -5

mindspore/ops/composite/__init__.py  +3 -2

@@ -28,7 +28,7 @@ from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
from .math_ops import count_nonzero, tensor_dot
-from .array_ops import repeat_elements
+from .array_ops import repeat_elements, sequence_mask


__all__ = [
@@ -53,4 +53,5 @@ __all__ = [
'clip_by_global_norm',
'count_nonzero',
'tensor_dot',
-'repeat_elements']
+'repeat_elements',
+'sequence_mask']

mindspore/ops/composite/array_ops.py  +33 -0

@@ -20,6 +20,7 @@ from mindspore._checkparam import Rel
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from .. import operations as P
from ..operations import _inner_ops as inner


@constexpr
@@ -103,3 +104,35 @@ def repeat_elements(x, rep, axis=0):
x_rep = reshape_op(x_expand, x_reshape)

return x_rep

def sequence_mask(lengths, maxlen):
"""
Returns a mask tensor representing the first N positions of each cell.

If `lengths` has shape [d_1, d_2, ..., d_n], then the resulting mask tensor has shape
[d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]).

Args:
lengths (Tensor): Tensor to calculate the mask for. All values in this tensor must be
less than `maxlen`. Must be of type int32 or int64.

maxlen (int): Size of the last dimension of the returned tensor. Must be positive and the
same type as the elements of `lengths`.

Outputs:
One mask tensor of shape lengths.shape + (maxlen,).

Supported Platforms:
``GPU``

Examples:
>>> x = Tensor(np.array([[1, 3], [2, 0]]))
>>> output = C.sequence_mask(x, 3)
>>> print(output)
[[[True, False, False],
[True, True, True]],
[[True, True, False],
[False, False, False]]]
"""
return inner.SequenceMask()(lengths, maxlen)
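
For reference (not part of this commit), the mask described above can be reproduced with plain NumPy broadcasting; the helper name below is illustrative only.

import numpy as np

def sequence_mask_reference(lengths, maxlen):
    # Compare every position index j in [0, maxlen) with the corresponding
    # length, broadcasting over the leading dimensions of `lengths`.
    return np.arange(maxlen) < np.expand_dims(lengths, -1)

# Reproduces the docstring example: lengths [[1, 3], [2, 0]], maxlen 3.
print(sequence_mask_reference(np.array([[1, 3], [2, 0]]), 3))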

mindspore/ops/operations/__init__.py  +1 -2

@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
-Unique, GatherD, Identity, SequenceMask)
+Unique, GatherD, Identity)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice,
@@ -400,7 +400,6 @@ __all__ = [
"Pull",
"ReLUV2",
"SparseToDense",
"SequenceMask",
]

__all__.sort()
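
Since `SequenceMask` is dropped from the public `operations` namespace here, existing call sites move to the composite entry point. A minimal migration sketch (my illustration; the variable names and the GPU context line are assumptions):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import composite as C

# The op is documented as GPU-only, so target a GPU device.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# Old call sites used P.SequenceMask()(lengths, 3) via `operations as P`;
# that export is removed by this commit, so the composite wrapper is used instead.
lengths = Tensor(np.array([1, 3, 2]).astype(np.int32))
mask = C.sequence_mask(lengths, 3)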

mindspore/ops/operations/_inner_ops.py  +44 -0

@@ -679,3 +679,47 @@ class ErrorOnDynamicShapeInput(PrimitiveWithInfer):

def infer_value(self, input_tensor):
return input_tensor


class SequenceMask(PrimitiveWithCheck):
"""
Returns a mask tensor representing the first N positions of each cell.

If `lengths` has shape [d_1, d_2, ..., d_n], then the resulting mask tensor has shape
[d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]).

Inputs:
- **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor must be
less than `maxlen`. Must be of type int32 or int64.

- **maxlen** (int) - Size of the last dimension of the returned tensor. Must be positive and the
same type as the elements of `lengths`.

Outputs:
One mask tensor of shape lengths.shape + (maxlen,).

Supported Platforms:
``GPU``

Examples:
>>> x = Tensor(np.array([[1, 3], [2, 0]]))
>>> sequence_mask = inner.SequenceMask()
>>> output = sequence_mask(x, 3)
>>> print(output)
[[[True, False, False],
[True, True, True]],
[[True, True, False],
[False, False, False]]]
"""

@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])

def check_shape(self, lengths_shape, maxlen_shape):
validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name)
validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name)

def check_dtype(self, lengths_dtype, maxlen_dtype):
validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name)
validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)

mindspore/ops/operations/array_ops.py  +0 -44

@@ -4720,47 +4720,3 @@ class Identity(PrimitiveWithInfer):
'dtype': x['dtype'],
'value': None}
return out


class SequenceMask(PrimitiveWithCheck):
"""
Returns a mask tensor representing the first N positions of each cell.

If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape
[d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])

Inputs:
- **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor must be
less than `maxlen`. Must be type int32 or int64.

- **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
tyupe as elements in `lengths`.

Outputs:
One mask tensor of shape lengths.shape + (maxlen,).

Supported Platforms:
``GPU``

Examples:
>>> x = Tensor(np.array([[1, 3], [2, 0]])
>>> sequence_mask = P.SequenceMask()
>>> output = sequence_mask(x, 3)
>>> print(output)
[[[True, False, False],
[True, True, True]],
[[True, True, False],
[False, False, False]]]
"""

@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])

def check_shape(self, lengths_shape, maxlen_shape):
validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name)
validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name)

def check_dtype(self, lengths_dtype, maxlen_dtype):
validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name)
validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)

tests/st/ops/gpu/test_sequence_mask_op.py  +3 -5

@@ -2,14 +2,13 @@ import numpy as np
import pytest

from mindspore import Tensor
-from mindspore.ops import operations as P
+from mindspore.ops import composite as C
from mindspore.ops.operations import _inner_ops as inner
import mindspore.nn as nn
import mindspore.context as context

def sequence_mask(x, maxlen):
-sequence_mask_op = P.SequenceMask()
-return sequence_mask_op(Tensor(x.astype(np.int32)), maxlen)
+return C.sequence_mask(Tensor(x.astype(np.int32)), maxlen)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@@ -87,11 +86,10 @@ def test_sequence_mask_dynamic():
super(SequenceMaskDynamicNet, self).__init__()
self.maxlen = maxlen
self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
-self.sequence_mask = P.SequenceMask()

def construct(self, x):
converted_to_dynamic_shape = self.convert_to_dynamic_shape(x)
-return self.sequence_mask(converted_to_dynamic_shape, self.maxlen)
+return C.sequence_mask(converted_to_dynamic_shape, self.maxlen)

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
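
For completeness, a self-contained variant of the refactored static-shape test (a sketch; `SequenceMaskNet` and the sample values are made up, but the pattern mirrors the test file above):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class SequenceMaskNet(nn.Cell):
    # Thin Cell wrapper so the composite runs under graph mode.
    def __init__(self, maxlen):
        super(SequenceMaskNet, self).__init__()
        self.maxlen = maxlen

    def construct(self, lengths):
        return C.sequence_mask(lengths, self.maxlen)

net = SequenceMaskNet(maxlen=4)
output = net(Tensor(np.array([2, 4, 1]).astype(np.int32)))
expected = np.array([[True, True, False, False],
                     [True, True, True, True],
                     [True, False, False, False]])
assert np.array_equal(output.asnumpy(), expected)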


