
add op: ReverseSequence

tags/v0.6.0-beta
xutianchun committed 5 years ago
commit 56913ff1cf
4 changed files with 132 additions and 1 deletion:

  1. mindspore/ops/_op_impl/aicpu/__init__.py (+1, -0)
  2. mindspore/ops/_op_impl/aicpu/reverse_sequence.py (+78, -0)
  3. mindspore/ops/operations/__init__.py (+2, -1)
  4. mindspore/ops/operations/array_ops.py (+51, -0)

mindspore/ops/_op_impl/aicpu/__init__.py (+1, -0)

@@ -27,3 +27,4 @@ from .random_choice_with_mask import _random_choice_with_mask_aicpu
from .ctcloss import _ctcloss_aicpu
from .rnnt_loss import _rnnt_loss_aicpu
from .random_categorical import _random_categorical_aicpu
from .reverse_sequence import _reverse_sequence_aicpu
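
The new import line is the whole change here: Python decorators run at module import time, so pulling reverse_sequence into the package makes its @op_info_register decorator record the AiCPU kernel metadata as soon as the package is loaded. A minimal, hypothetical stand-in for that mechanism (the names _OP_INFO_REGISTRY, op_info_register_demo and _reverse_sequence_demo are illustrative only, not MindSpore APIs):

# Illustrative sketch: a toy registry showing how a decorator can record
# op metadata at import time. Not the real op_info_register implementation.
_OP_INFO_REGISTRY = {}

def op_info_register_demo(op_info):
    """Store op_info when the decorated stub is defined, i.e. on import."""
    def decorator(func):
        _OP_INFO_REGISTRY[func.__name__] = op_info
        return func
    return decorator

@op_info_register_demo({"op_name": "ReverseSequence", "imply_type": "AiCPU"})
def _reverse_sequence_demo():
    """Stub body: importing the defining module is enough to register."""
    return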

mindspore/ops/_op_impl/aicpu/reverse_sequence.py (+78, -0)

@@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ReverseSequence op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
reverse_sequence_op_info = AiCPURegOp("ReverseSequence") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "seq_lengths", "required") \
    .output(0, "y", "required") \
    .attr("seq_dim", "int") \
    .attr("batch_dim", "int") \
    .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_NCHW, DataType.I32_NCHW, DataType.BOOL_NCHW) \
    .dtype_format(DataType.I8_NCHW, DataType.I32_NCHW, DataType.I8_NCHW) \
    .dtype_format(DataType.I16_NCHW, DataType.I32_NCHW, DataType.I16_NCHW) \
    .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW) \
    .dtype_format(DataType.I64_NCHW, DataType.I32_NCHW, DataType.I64_NCHW) \
    .dtype_format(DataType.U8_NCHW, DataType.I32_NCHW, DataType.U8_NCHW) \
    .dtype_format(DataType.U16_NCHW, DataType.I32_NCHW, DataType.U16_NCHW) \
    .dtype_format(DataType.U32_NCHW, DataType.I32_NCHW, DataType.U32_NCHW) \
    .dtype_format(DataType.U64_NCHW, DataType.I32_NCHW, DataType.U64_NCHW) \
    .dtype_format(DataType.F16_NCHW, DataType.I32_NCHW, DataType.F16_NCHW) \
    .dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \
    .dtype_format(DataType.F64_NCHW, DataType.I32_NCHW, DataType.F64_NCHW) \
    .dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_NCHW, DataType.I64_NCHW, DataType.BOOL_NCHW) \
    .dtype_format(DataType.I8_NCHW, DataType.I64_NCHW, DataType.I8_NCHW) \
    .dtype_format(DataType.I16_NCHW, DataType.I64_NCHW, DataType.I16_NCHW) \
    .dtype_format(DataType.I32_NCHW, DataType.I64_NCHW, DataType.I32_NCHW) \
    .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW, DataType.I64_NCHW) \
    .dtype_format(DataType.U8_NCHW, DataType.I64_NCHW, DataType.U8_NCHW) \
    .dtype_format(DataType.U16_NCHW, DataType.I64_NCHW, DataType.U16_NCHW) \
    .dtype_format(DataType.U32_NCHW, DataType.I64_NCHW, DataType.U32_NCHW) \
    .dtype_format(DataType.U64_NCHW, DataType.I64_NCHW, DataType.U64_NCHW) \
    .dtype_format(DataType.F16_NCHW, DataType.I64_NCHW, DataType.F16_NCHW) \
    .dtype_format(DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \
    .dtype_format(DataType.F64_NCHW, DataType.I64_NCHW, DataType.F64_NCHW) \
    .get_op_info()

@op_info_register(reverse_sequence_op_info)
def _reverse_sequence_aicpu():
"""ReverseSequence AiCPU register"""
return
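
The 48 dtype_format rows are the cross product of twelve x dtypes, two formats (Default and NCHW), and two seq_lengths dtypes (int32 and int64), with y always matching the dtype and format of x. As a sanity check of that pattern, the same table could be generated with loops; this sketch is only an illustration of the structure, not how the file is written, and the names reg and generated_op_info are local to the example:

# Illustrative only: rebuild the same op info with loops to show the pattern.
from mindspore.ops.op_info_register import AiCPURegOp, DataType

reg = AiCPURegOp("ReverseSequence") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "seq_lengths", "required") \
    .output(0, "y", "required") \
    .attr("seq_dim", "int") \
    .attr("batch_dim", "int")

for seq_t in ("I32", "I64"):                      # seq_lengths dtype
    for fmt in ("Default", "NCHW"):               # data format
        for x_t in ("BOOL", "I8", "I16", "I32", "I64",
                    "U8", "U16", "U32", "U64", "F16", "F32", "F64"):
            reg = reg.dtype_format(getattr(DataType, "%s_%s" % (x_t, fmt)),
                                   getattr(DataType, "%s_%s" % (seq_t, fmt)),
                                   getattr(DataType, "%s_%s" % (x_t, fmt)))

generated_op_info = reg.get_op_info()             # same 48 combinations as above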

mindspore/ops/operations/__init__.py (+2, -1)

@@ -30,7 +30,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        Squeeze, StridedSlice, Tile,
                        Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
                        UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
-                       SpaceToBatchND, BatchToSpaceND)
+                       SpaceToBatchND, BatchToSpaceND, ReverseSequence)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                       _MirrorOperator, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice,
@@ -278,6 +278,7 @@ __all__ = [
    "ApplyCenteredRMSProp",
    "SpaceToBatchND",
    "BatchToSpaceND",
    "ReverseSequence",
    "SquareSumAll",
    "BitwiseAnd",
    "BitwiseOr",


mindspore/ops/operations/array_ops.py (+51, -0)

@@ -2720,3 +2720,54 @@ class BatchToSpaceND(PrimitiveWithInfer):
                             f'block_shape_prod {block_shape_prod}')
        out_shape[0] = out_shape[0] // block_shape_prod
        return out_shape


class ReverseSequence(PrimitiveWithInfer):
    """
    Reverses variable length slices.

    Note:
        If the specified axis is negative, it is counted backward from the last dimension, starting at -1.

    Args:
        seq_dim (int): The dimension which is partially reversed. Required.
        batch_dim (int): The dimension along which the input is sliced into batches. Default: 0.

    Inputs:
        - **x** (Tensor) - The input tensor to be partially reversed.
        - **seq_lengths** (Tensor) - Must be a 1-D vector of type int32 or int64.

    Outputs:
        Reversed tensor with the same shape and data type as `x`.

    Raises:
        ValueError: If `seq_dim` or `batch_dim` is not an integer or not in the valid range.

    Examples:
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([1, 2, 3]), mindspore.int32)
        >>> reverse_sequence = P.ReverseSequence(seq_dim=1)
        >>> output = reverse_sequence(x, seq_lengths)
    """

    @prim_attr_register
    def __init__(self, seq_dim, batch_dim=0):
        """init ReverseSequence"""
        self.init_prim_io_names(inputs=['x', 'seq_lengths'], outputs=['y'])
        validator.check_value_type("seq_dim", seq_dim, [int], self.name)
        self.seq_dim_ = seq_dim
        validator.check_value_type("batch_dim", batch_dim, [int], self.name)
        self.batch_dim_ = batch_dim

    def infer_shape(self, x, seq_lengths):
        validator.check("seq_dim", self.seq_dim_, "x rank", len(x), Rel.LT, self.name)
        validator.check("batch_dim", self.batch_dim_, "x rank", len(x), Rel.LT, self.name)
        validator.check("batch_dim", self.batch_dim_, "seq_dim", self.seq_dim_, Rel.NE, self.name)
        validator.check("seq_lengths rank", len(seq_lengths), "expected", 1, Rel.EQ, self.name)
        validator.check("seq_lengths vector size", seq_lengths[0],
                        "input size along batch_dim", x[self.batch_dim_], Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, seq_lengths):
        validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
        return x
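
For readers checking the semantics, here is a plain-NumPy reference model (a hypothetical helper written for this note, not part of the change): for each slice i taken along batch_dim, the first seq_lengths[i] elements along seq_dim are reversed and everything else is left untouched.

import numpy as np

def reverse_sequence_ref(x, seq_lengths, seq_dim, batch_dim=0):
    """NumPy reference model of ReverseSequence (illustrative only)."""
    x = np.asarray(x)
    y = x.copy()
    for i, length in enumerate(seq_lengths):
        src = [slice(None)] * x.ndim
        src[batch_dim] = i
        # Reversed view of the first `length` elements along seq_dim.
        src[seq_dim] = slice(length - 1, None, -1) if length > 0 else slice(0, 0)
        dst = list(src)
        dst[seq_dim] = slice(0, length)
        y[tuple(dst)] = x[tuple(src)]
    return y

# Matches the docstring example: seq_dim=1, batch_dim=0, lengths [1, 2, 3].
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
print(reverse_sequence_ref(x, [1, 2, 3], seq_dim=1))
# [[1. 2. 3.]
#  [5. 4. 6.]
#  [9. 8. 7.]]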
