diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.cc new file mode 100644 index 0000000000..c42927eab2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h" + +namespace mindspore { +namespace kernel { + +// keep this as TWO but output is always bool, just in case framework can +// support passing optional dtype and then we can be identical to tf +MS_REG_GPU_KERNEL_TWO( + SequenceMask, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + SequenceMaskGpuKernel, int32_t, bool) + +MS_REG_GPU_KERNEL_TWO( + SequenceMask, + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), + SequenceMaskGpuKernel, int64_t, bool) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h new file mode 100644 index 0000000000..314d4e102e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_ + +#include "backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh" + +#include + +#include + +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class SequenceMaskGpuKernel : public GpuKernel { + public: + SequenceMaskGpuKernel() { ResetResource(); } + ~SequenceMaskGpuKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *lengths_device_address = GetDeviceAddress(inputs, 0); + T *maxlen_device_address = GetDeviceAddress(inputs, 1); + S *output_device_address = GetDeviceAddress(outputs, 0); + + CalSequenceMask(lengths_device_address, maxlen_device_address, output_device_address, output_size_, + reinterpret_cast(stream_ptr)); + + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_count != 2) { + MS_LOG(EXCEPTION) << input_count << " inputs were provided, but SequenceMaskGpuKernel expects 2."; + } + + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (const int &e : input_shape_) { + lengths_size_ *= e; + } + + std::vector inferred_output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (const size_t &e : inferred_output_shape) { + output_size_ *= e; + } + + InitSizeLists(); + + return true; + } + + void ResetResource() noexcept override { + output_size_ = 1; + lengths_size_ = 1; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(lengths_size_ * sizeof(T)); + input_size_list_.push_back(sizeof(T)); + output_size_list_.push_back(output_size_); + } + + private: + std::vector input_shape_; + size_t lengths_size_; + size_t output_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cu new file mode 100644 index 0000000000..1bdc72b4ef --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cu @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "sequence_mask_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +__global__ void ValidateArgs(int *maxlen, const int lengths_size, const int max_output_size) { + int maxlen_value = *maxlen; + if (maxlen_value < 0 || lengths_size * maxlen_value > max_output_size) { + asm("trap;"); + } +} + +template +__global__ void SequenceMask( + const T *input, T *maxlen, S *output, const size_t output_size) { + T maxlen_value = *maxlen; + + for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) { + T mask_comparison_value = gt_id % maxlen_value; + T input_comparison_index = (gt_id - mask_comparison_value) / maxlen_value; + S result = mask_comparison_value < input[input_comparison_index]; + output[gt_id] = result; + } +} + +template +void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream) { + SequenceMask<<>>(lengths, maxlen, output, output_size); +} + +template void CalSequenceMask(const int *lengths, int *maxlen, bool *output, const size_t output_size, + cudaStream_t cuda_stream); + +template void CalSequenceMask(const int64_t *lengths, int64_t *maxlen, bool *output, + const size_t output_size, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh new file mode 100644 index 0000000000..241c0134d1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_ + +#include + +template +void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_ diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index d88a4d872c..de1179c905 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -265,6 +265,9 @@ AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &prim const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + template AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { // Inputs: a tuple or list or dict. diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 3a86daaace..b39c93cfba 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -839,5 +839,56 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr } return std::make_shared(output_list); } + +AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string &op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + + AbstractTensorPtr lengths = CheckArg(op_name, args_spec_list, 0); + (void)CheckTensorDType(lengths, {kInt32, kInt64}, "Input 1 (lengths) for SequenceMask should be one of: %s"); + + int64_t maxlen_value = 0; + + if (args_spec_list[1]->isa()) { + AbstractScalarPtr maxlen = CheckArg(op_name, args_spec_list, 1); + (void)CheckScalarType(maxlen, {kInt32, kInt64}, "Input 0 (maxlen) for SequenceMask should be one of: %s"); + + TypePtr maxlen_type = nullptr; + maxlen_type = maxlen->GetTypeTrack(); + MS_EXCEPTION_IF_NULL(maxlen_type); + + if (maxlen_type->type_id() == TypeId::kNumberTypeInt32) { + maxlen_value = static_cast(GetValue(maxlen->BuildValue())); + } else if (maxlen_type->type_id() == TypeId::kNumberTypeInt64) { + maxlen_value = GetValue(maxlen->BuildValue()); + } + } else if (args_spec_list[1]->isa()) { + auto maxlen_tensor_ptr = args_spec_list[1]->cast(); + MS_EXCEPTION_IF_NULL(maxlen_tensor_ptr); + auto maxlen_value_ptr = maxlen_tensor_ptr->BuildValue(); + MS_EXCEPTION_IF_NULL(maxlen_value_ptr); + auto maxlen_tensor = maxlen_value_ptr->cast(); + MS_EXCEPTION_IF_NULL(maxlen_tensor); + maxlen_value = *static_cast(maxlen_tensor->data_c()); + } + + ShapeVector lengths_shape = lengths->shape()->shape(); + ShapeVector lengths_shape_min = lengths->shape()->min_shape(); + if (lengths_shape_min.empty()) { + lengths_shape_min = lengths_shape; + } + ShapeVector lengths_shape_max = lengths->shape()->max_shape(); + if (lengths_shape_max.empty()) { + lengths_shape_max = lengths_shape; + } + + lengths_shape.push_back(maxlen_value); + lengths_shape_min.push_back(maxlen_value); + lengths_shape_max.push_back(maxlen_value); + + ShapePtr output_shape = std::make_shared(lengths_shape, lengths_shape_min, lengths_shape_max); + return std::make_shared(kBool, output_shape); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 6f92484097..f8dff15bfd 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -72,6 +72,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimTranspose, {InferImplTranspose, true}}, {prim::kPrimReshape, {InferImplReshape, true}}, {prim::kPrimSplit, {InferImplSplit, true}}, + {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, // Structure {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, {prim::kPrimMakeList, {InferImplMakeList, true}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 2ec7b17a4f..d877411806 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -119,6 +119,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared("D inline const PrimitivePtr kPrimScatterAdd = std::make_shared("ScatterAdd"); inline const PrimitivePtr kPrimScatterUpdate = std::make_shared("ScatterUpdate"); inline const PrimitivePtr kPrimSplit = std::make_shared("Split"); +inline const PrimitivePtr kPrimSequenceMask = std::make_shared("SequenceMask"); // NN inline const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index b6f63f1146..8a1cbc26ad 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -22,7 +22,7 @@ A collection of operators to build neural networks or to compute functions. from .image_ops import (CropAndResize) from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Diag, DiagPart, DType, ExpandDims, Eye, - Fill, Ones, Zeros, SequenceMask, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, + Fill, Ones, Zeros, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid, SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, @@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax, UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, - Unique, GatherD, Identity) + Unique, GatherD, Identity, SequenceMask) from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, _MirrorOperator, ReduceOp, _VirtualDataset, _VirtualDiv, _GetTensorSlice, Send, Receive, @@ -394,6 +394,7 @@ __all__ = [ "Pull", "ReLUV2", "SparseToDense", + "SequenceMask", ] __all__.sort() diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 245cbae48b..1241005510 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1216,68 +1216,6 @@ class Zeros(PrimitiveWithInfer): return out -class SequenceMask(PrimitiveWithInfer): - r""" - Generates sequence mask according to input lengths. - - Creates a mask tensor which retains the first N elements in tensor by setting the values - to be True or one. The rest values in mask are set to False or zero. - - Args: - max_length (int): Nonnegative integer, size of the last dimension in mask. Default: None. - - Inputs: - - **lengths** (Union[tuple[int], list[int]]) - Defines the first N elements that are retained. - Only constant value is allowed. - - **dtype** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed. - - Outputs: - Tensor. - If max_length is set, the shape of the output is (lengths.shape, max_length). - If max_length is not set and the biggest value in lengths is x. Then, the shape of - the output is (lengths.shape, x). - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> from mindspore.ops import operations as P - >>> sequence_mask = P.SequenceMask() - >>> mask = sequence_mask([2, 2, 4], mindspore.int32) - >>> print(mask) - [[1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 1, 1]] - - """ - - @prim_attr_register - def __init__(self): - """Initialize SequenceMask""" - - def __infer__(self, lengths, dtype, max_length=None): - validator.check_value_type("shape", lengths['value'], [tuple, list], self.name) - valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, - mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, - mstype.float16, mstype.float32, mstype.float64] - validator.check_subclass("dtype", dtype['value'], valid_types, self.name) - nptype = mstype.dtype_to_nptype(dtype['value']) - if max_length is None: - max_length = np.max(lengths['value']) - else: - validator.check_non_negative_int(max_length['value']) - max_length = max_length['value'] - row_vector = np.arange(0, max_length) - col_matrix = np.expand_dims(lengths['value'], -1) - result = (row_vector < col_matrix).astype(nptype) - out = { - 'value': Tensor(result), - 'shape': result.shape, - 'dtype': dtype['value'] - } - return out - - class OnesLike(PrimitiveWithInfer): """ Creates a new tensor. The values of all elements are 1. @@ -4648,3 +4586,47 @@ class Identity(PrimitiveWithInfer): 'dtype': x['dtype'], 'value': None} return out + + +class SequenceMask(PrimitiveWithCheck): + """ + Returns a mask tensor representing the first N positions of each cell. + + If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape + [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]) + + Inputs: + - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor must be + less than `maxlen`. Must be type int32 or int64. + + - **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same + tyupe as elements in `lengths`. + + Outputs: + One mask tensor of shape lengths.shape + (maxlen,). + + Supported Platforms: + ``GPU`` + + Examples: + >>> x = Tensor(np.array([[1, 3], [2, 0]]) + >>> sequence_mask = P.SequenceMask() + >>> output = sequence_mask(x, 3) + >>> print(output) + [[[True, False, False], + [True, True, True]], + [[True, True, False], + [False, False, False]]] + """ + + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"]) + + def check_shape(self, lengths_shape, maxlen_shape): + validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name) + validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name) + + def check_dtype(self, lengths_dtype, maxlen_dtype): + validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name) + validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name) diff --git a/tests/st/ops/gpu/test_sequence_mask_op.py b/tests/st/ops/gpu/test_sequence_mask_op.py new file mode 100644 index 0000000000..32262dd3af --- /dev/null +++ b/tests/st/ops/gpu/test_sequence_mask_op.py @@ -0,0 +1,117 @@ +import numpy as np +import pytest + +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _inner_ops as inner +import mindspore.nn as nn +import mindspore.context as context + +def sequence_mask(x, maxlen): + sequence_mask_op = P.SequenceMask() + return sequence_mask_op(Tensor(x.astype(np.int32)), maxlen) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sequence_mask_1d(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + a = np.array([2, 3, 1]) + + maxlen = 4 + ms_out = sequence_mask(a, maxlen) + expected_out = Tensor(np.array([[True, True, False, False], + [True, True, True, False], + [True, False, False, False]])) + np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sequence_mask_2d(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + a = np.array([[0, 1, 3, 2], [1, 4, 4, 2]]) + + maxlen = 6 + ms_out = sequence_mask(a, maxlen) + expected_out = Tensor(np.array([[[False, False, False, False, False, False], + [True, False, False, False, False, False], + [True, True, True, False, False, False], + [True, True, False, False, False, False]], + [[True, False, False, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, False, False, False, False]]])) + np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sequence_mask_3d(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + a = np.array([[[2, 2], [1, 1]], + [[2, 0], [2, 1]], + [[0, 0], [0, 0]]]) + + maxlen = 2 + ms_out = sequence_mask(a, maxlen) + expected_out = Tensor(np.array([[[[True, True], [True, True]], [[True, False], [True, False]]], + [[[True, True], [False, False]], [[True, True], [True, False]]], + [[[False, False], [False, False]], [[False, False], [False, False]]]])) + + np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sequence_mask_maxlen_1(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + a = np.array([[[0, 1], [1, 1]], + [[1, 0], [1, 1]], + [[0, 1], [0, 1]]]) + + maxlen = 1 + ms_out = sequence_mask(a, maxlen) + expected_out = Tensor(np.array([[[[False], [True]], [[True], [True,]]], + [[[True], [False]], [[True], [True]]], + [[[False], [True]], [[False], [True]]]])) + + np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sequence_mask_dynamic(): + class SequenceMaskDynamicNet(nn.Cell): + def __init__(self, maxlen): + super(SequenceMaskDynamicNet, self).__init__() + self.maxlen = maxlen + self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() + self.sequence_mask = P.SequenceMask() + + def construct(self, x): + converted_to_dynamic_shape = self.convert_to_dynamic_shape(x) + return self.sequence_mask(converted_to_dynamic_shape, self.maxlen) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + sequence_mask_net = SequenceMaskDynamicNet(4) + + a = Tensor(np.array([0, 1, 0, 2, 0, 5])) + ms_out = sequence_mask_net(a) + expected_out = Tensor(np.array([[False, False, False, False], + [True, False, False, False], + [False, False, False, False], + [True, True, False, False], + [False, False, False, False], + [True, True, True, True]])) + np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) + + a = Tensor(np.array([[4, 3, 0], [0, 1, 3]])) + ms_out = sequence_mask_net(a) + expected_out = Tensor(np.array([[[True, True, True, True], + [True, True, True, False], + [False, False, False, False]], + [[False, False, False, False], + [True, False, False, False], + [True, True, True, False]]])) diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py index b0808828af..3992508265 100644 --- a/tests/ut/python/ops/test_array_ops.py +++ b/tests/ut/python/ops/test_array_ops.py @@ -42,28 +42,6 @@ def test_expand_dims(): assert output.asnumpy().shape == (1, 2, 2) -def test_sequence_mask(): - list_ = [2, 2, 4] - sequence_mask = P.SequenceMask() - mask1 = sequence_mask(list_, mstype.int32) - mask2 = sequence_mask(list_, mstype.int32, 5) - assert mask1.shape == (3, 4) - assert mask1.dtype == mstype.int32 - assert mask2.shape == (3, 5) - assert mask2.dtype == mstype.int32 - - -def test_sequence_mask_1(): - list_ = [[2, 2, 4], [3, 4, 4]] - sequence_mask = P.SequenceMask() - mask1 = sequence_mask(list_, mstype.bool_) - mask2 = sequence_mask(list_, mstype.bool_, 5) - assert mask1.shape == (2, 3, 4) - assert mask1.dtype == mstype.bool_ - assert mask2.shape == (2, 3, 5) - assert mask2.dtype == mstype.bool_ - - def test_cast(): input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) input_x = Tensor(input_np)