| @@ -28,7 +28,7 @@ from .multitype_ops.ones_like_impl import ones_like | |||
| from .multitype_ops.zeros_like_impl import zeros_like | |||
| from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial | |||
| from .math_ops import count_nonzero, tensor_dot | |||
| from .array_ops import repeat_elements | |||
| from .array_ops import repeat_elements, sequence_mask | |||
| __all__ = [ | |||
| @@ -53,4 +53,5 @@ __all__ = [ | |||
| 'clip_by_global_norm', | |||
| 'count_nonzero', | |||
| 'tensor_dot', | |||
| 'repeat_elements'] | |||
| 'repeat_elements', | |||
| 'sequence_mask'] | |||
| @@ -20,6 +20,7 @@ from mindspore._checkparam import Rel | |||
| from mindspore.ops.primitive import constexpr | |||
| from mindspore.ops import functional as F | |||
| from .. import operations as P | |||
| from ..operations import _inner_ops as inner | |||
| @constexpr | |||
| @@ -103,3 +104,35 @@ def repeat_elements(x, rep, axis=0): | |||
| x_rep = reshape_op(x_expand, x_reshape) | |||
| return x_rep | |||
| def sequence_mask(lengths, maxlen): | |||
| """ | |||
| Returns a mask tensor representing the first N positions of each cell. | |||
| If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape | |||
| [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]) | |||
| Args: | |||
| length (Tensor): Tensor to calculate the mask for. All values in this tensor must be | |||
| less than `maxlen`. Must be type int32 or int64. | |||
| maxlen (int): size of the last dimension of returned tensor. Must be positive and same | |||
| type as elements in `lengths`. | |||
| Outputs: | |||
| One mask tensor of shape lengths.shape + (maxlen,). | |||
| Supported Platforms: | |||
| ``GPU`` | |||
| Examples: | |||
| >>> x = Tensor(np.array([[1, 3], [2, 0]]) | |||
| >>> sequence_mask = P.SequenceMask() | |||
| >>> output = sequence_mask(x, 3) | |||
| >>> print(output) | |||
| [[[True, False, False], | |||
| [True, True, True]], | |||
| [[True, True, False], | |||
| [False, False, False]]] | |||
| """ | |||
| return inner.SequenceMask()(lengths, maxlen) | |||
| @@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax, | |||
| UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | |||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | |||
| Unique, GatherD, Identity, SequenceMask) | |||
| Unique, GatherD, Identity) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | |||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice, | |||
| @@ -400,7 +400,6 @@ __all__ = [ | |||
| "Pull", | |||
| "ReLUV2", | |||
| "SparseToDense", | |||
| "SequenceMask", | |||
| ] | |||
| __all__.sort() | |||
| @@ -679,3 +679,47 @@ class ErrorOnDynamicShapeInput(PrimitiveWithInfer): | |||
| def infer_value(self, input_tensor): | |||
| return input_tensor | |||
| class SequenceMask(PrimitiveWithCheck): | |||
| """ | |||
| Returns a mask tensor representing the first N positions of each cell. | |||
| If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape | |||
| [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]) | |||
| Inputs: | |||
| - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor must be | |||
| less than `maxlen`. Must be type int32 or int64. | |||
| - **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same | |||
| type as elements in `lengths`. | |||
| Outputs: | |||
| One mask tensor of shape lengths.shape + (maxlen,). | |||
| Supported Platforms: | |||
| ``GPU`` | |||
| Examples: | |||
| >>> x = Tensor(np.array([[1, 3], [2, 0]]) | |||
| >>> sequence_mask = P.SequenceMask() | |||
| >>> output = sequence_mask(x, 3) | |||
| >>> print(output) | |||
| [[[True, False, False], | |||
| [True, True, True]], | |||
| [[True, True, False], | |||
| [False, False, False]]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"]) | |||
| def check_shape(self, lengths_shape, maxlen_shape): | |||
| validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name) | |||
| validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name) | |||
| def check_dtype(self, lengths_dtype, maxlen_dtype): | |||
| validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name) | |||
| validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name) | |||
| @@ -4720,47 +4720,3 @@ class Identity(PrimitiveWithInfer): | |||
| 'dtype': x['dtype'], | |||
| 'value': None} | |||
| return out | |||
| class SequenceMask(PrimitiveWithCheck): | |||
| """ | |||
| Returns a mask tensor representing the first N positions of each cell. | |||
| If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape | |||
| [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]) | |||
| Inputs: | |||
| - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor must be | |||
| less than `maxlen`. Must be type int32 or int64. | |||
| - **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same | |||
| tyupe as elements in `lengths`. | |||
| Outputs: | |||
| One mask tensor of shape lengths.shape + (maxlen,). | |||
| Supported Platforms: | |||
| ``GPU`` | |||
| Examples: | |||
| >>> x = Tensor(np.array([[1, 3], [2, 0]]) | |||
| >>> sequence_mask = P.SequenceMask() | |||
| >>> output = sequence_mask(x, 3) | |||
| >>> print(output) | |||
| [[[True, False, False], | |||
| [True, True, True]], | |||
| [[True, True, False], | |||
| [False, False, False]]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"]) | |||
| def check_shape(self, lengths_shape, maxlen_shape): | |||
| validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name) | |||
| validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name) | |||
| def check_dtype(self, lengths_dtype, maxlen_dtype): | |||
| validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name) | |||
| validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name) | |||
| @@ -2,14 +2,13 @@ import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| def sequence_mask(x, maxlen): | |||
| sequence_mask_op = P.SequenceMask() | |||
| return sequence_mask_op(Tensor(x.astype(np.int32)), maxlen) | |||
| return C.sequence_mask(Tensor(x.astype(np.int32)), maxlen) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @@ -87,11 +86,10 @@ def test_sequence_mask_dynamic(): | |||
| super(SequenceMaskDynamicNet, self).__init__() | |||
| self.maxlen = maxlen | |||
| self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() | |||
| self.sequence_mask = P.SequenceMask() | |||
| def construct(self, x): | |||
| converted_to_dynamic_shape = self.convert_to_dynamic_shape(x) | |||
| return self.sequence_mask(converted_to_dynamic_shape, self.maxlen) | |||
| return C.sequence_mask(converted_to_dynamic_shape, self.maxlen) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||