From: @peilin-wang Reviewed-by: @tom__chen Signed-off-by:tags/v1.1.0
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cstdint> | |||
| #include "backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| // keep this as TWO but output is always bool, just in case framework can | |||
| // support passing optional dtype and then we can be identical to tf | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| SequenceMask, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), | |||
| SequenceMaskGpuKernel, int32_t, bool) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| SequenceMask, | |||
| KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), | |||
| SequenceMaskGpuKernel, int64_t, bool) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,101 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh" | |||
| #include <cuda_runtime.h> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T, typename S> | |||
| class SequenceMaskGpuKernel : public GpuKernel { | |||
| public: | |||
| SequenceMaskGpuKernel() { ResetResource(); } | |||
| ~SequenceMaskGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *lengths_device_address = GetDeviceAddress<T>(inputs, 0); | |||
| T *maxlen_device_address = GetDeviceAddress<T>(inputs, 1); | |||
| S *output_device_address = GetDeviceAddress<S>(outputs, 0); | |||
| CalSequenceMask(lengths_device_address, maxlen_device_address, output_device_address, output_size_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_count != 2) { | |||
| MS_LOG(EXCEPTION) << input_count << " inputs were provided, but SequenceMaskGpuKernel expects 2."; | |||
| } | |||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (const int &e : input_shape_) { | |||
| lengths_size_ *= e; | |||
| } | |||
| std::vector<size_t> inferred_output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| for (const size_t &e : inferred_output_shape) { | |||
| output_size_ *= e; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| output_size_ = 1; | |||
| lengths_size_ = 1; | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| workspace_size_list_.clear(); | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(lengths_size_ * sizeof(T)); | |||
| input_size_list_.push_back(sizeof(T)); | |||
| output_size_list_.push_back(output_size_); | |||
| } | |||
| private: | |||
| std::vector<size_t> input_shape_; | |||
| size_t lengths_size_; | |||
| size_t output_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cuda_runtime.h> | |||
| #include "sequence_mask_impl.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| __global__ void ValidateArgs(int *maxlen, const int lengths_size, const int max_output_size) { | |||
| int maxlen_value = *maxlen; | |||
| if (maxlen_value < 0 || lengths_size * maxlen_value > max_output_size) { | |||
| asm("trap;"); | |||
| } | |||
| } | |||
| template <typename T, typename S> | |||
| __global__ void SequenceMask( | |||
| const T *input, T *maxlen, S *output, const size_t output_size) { | |||
| T maxlen_value = *maxlen; | |||
| for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) { | |||
| T mask_comparison_value = gt_id % maxlen_value; | |||
| T input_comparison_index = (gt_id - mask_comparison_value) / maxlen_value; | |||
| S result = mask_comparison_value < input[input_comparison_index]; | |||
| output[gt_id] = result; | |||
| } | |||
| } | |||
| template <typename T, typename S> | |||
| void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream) { | |||
| SequenceMask<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(lengths, maxlen, output, output_size); | |||
| } | |||
| template void CalSequenceMask<int, bool>(const int *lengths, int *maxlen, bool *output, const size_t output_size, | |||
| cudaStream_t cuda_stream); | |||
| template void CalSequenceMask<int64_t, bool>(const int64_t *lengths, int64_t *maxlen, bool *output, | |||
| const size_t output_size, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_ | |||
| #include <cuda_runtime.h> | |||
| template <typename T, typename S> | |||
| void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_ | |||
| @@ -265,6 +265,9 @@ AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &prim | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| template <typename T> | |||
| AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tuple or list or dict. | |||
| @@ -839,5 +839,56 @@ AbstractBasePtr InferImplSplit(const AnalysisEnginePtr &, const PrimitivePtr &pr | |||
| } | |||
| return std::make_shared<AbstractTuple>(output_list); | |||
| } | |||
| AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string &op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 2); | |||
| AbstractTensorPtr lengths = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| (void)CheckTensorDType(lengths, {kInt32, kInt64}, "Input 1 (lengths) for SequenceMask should be one of: %s"); | |||
| int64_t maxlen_value = 0; | |||
| if (args_spec_list[1]->isa<AbstractScalar>()) { | |||
| AbstractScalarPtr maxlen = CheckArg<AbstractScalar>(op_name, args_spec_list, 1); | |||
| (void)CheckScalarType(maxlen, {kInt32, kInt64}, "Input 0 (maxlen) for SequenceMask should be one of: %s"); | |||
| TypePtr maxlen_type = nullptr; | |||
| maxlen_type = maxlen->GetTypeTrack(); | |||
| MS_EXCEPTION_IF_NULL(maxlen_type); | |||
| if (maxlen_type->type_id() == TypeId::kNumberTypeInt32) { | |||
| maxlen_value = static_cast<int64_t>(GetValue<int32_t>(maxlen->BuildValue())); | |||
| } else if (maxlen_type->type_id() == TypeId::kNumberTypeInt64) { | |||
| maxlen_value = GetValue<int64_t>(maxlen->BuildValue()); | |||
| } | |||
| } else if (args_spec_list[1]->isa<AbstractTensor>()) { | |||
| auto maxlen_tensor_ptr = args_spec_list[1]->cast<AbstractTensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(maxlen_tensor_ptr); | |||
| auto maxlen_value_ptr = maxlen_tensor_ptr->BuildValue(); | |||
| MS_EXCEPTION_IF_NULL(maxlen_value_ptr); | |||
| auto maxlen_tensor = maxlen_value_ptr->cast<tensor::TensorPtr>(); | |||
| MS_EXCEPTION_IF_NULL(maxlen_tensor); | |||
| maxlen_value = *static_cast<int64_t *>(maxlen_tensor->data_c()); | |||
| } | |||
| ShapeVector lengths_shape = lengths->shape()->shape(); | |||
| ShapeVector lengths_shape_min = lengths->shape()->min_shape(); | |||
| if (lengths_shape_min.empty()) { | |||
| lengths_shape_min = lengths_shape; | |||
| } | |||
| ShapeVector lengths_shape_max = lengths->shape()->max_shape(); | |||
| if (lengths_shape_max.empty()) { | |||
| lengths_shape_max = lengths_shape; | |||
| } | |||
| lengths_shape.push_back(maxlen_value); | |||
| lengths_shape_min.push_back(maxlen_value); | |||
| lengths_shape_max.push_back(maxlen_value); | |||
| ShapePtr output_shape = std::make_shared<Shape>(lengths_shape, lengths_shape_min, lengths_shape_max); | |||
| return std::make_shared<AbstractTensor>(kBool, output_shape); | |||
| } | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| @@ -72,6 +72,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimTranspose, {InferImplTranspose, true}}, | |||
| {prim::kPrimReshape, {InferImplReshape, true}}, | |||
| {prim::kPrimSplit, {InferImplSplit, true}}, | |||
| {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, | |||
| // Structure | |||
| {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, | |||
| {prim::kPrimMakeList, {InferImplMakeList, true}}, | |||
| @@ -119,6 +119,7 @@ inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("D | |||
| inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd"); | |||
| inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate"); | |||
| inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split"); | |||
| inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask"); | |||
| // NN | |||
| inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); | |||
| @@ -22,7 +22,7 @@ A collection of operators to build neural networks or to compute functions. | |||
| from .image_ops import (CropAndResize) | |||
| from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| Diag, DiagPart, DType, ExpandDims, Eye, | |||
| Fill, Ones, Zeros, SequenceMask, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, | |||
| Fill, Ones, Zeros, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, | |||
| IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, | |||
| Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid, | |||
| SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, | |||
| @@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, | |||
| Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax, | |||
| UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, | |||
| SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, | |||
| Unique, GatherD, Identity) | |||
| Unique, GatherD, Identity, SequenceMask) | |||
| from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, | |||
| _MirrorOperator, ReduceOp, _VirtualDataset, | |||
| _VirtualDiv, _GetTensorSlice, Send, Receive, | |||
| @@ -394,6 +394,7 @@ __all__ = [ | |||
| "Pull", | |||
| "ReLUV2", | |||
| "SparseToDense", | |||
| "SequenceMask", | |||
| ] | |||
| __all__.sort() | |||
| @@ -1216,68 +1216,6 @@ class Zeros(PrimitiveWithInfer): | |||
| return out | |||
| class SequenceMask(PrimitiveWithInfer): | |||
| r""" | |||
| Generates sequence mask according to input lengths. | |||
| Creates a mask tensor which retains the first N elements in tensor by setting the values | |||
| to be True or one. The rest values in mask are set to False or zero. | |||
| Args: | |||
| max_length (int): Nonnegative integer, size of the last dimension in mask. Default: None. | |||
| Inputs: | |||
| - **lengths** (Union[tuple[int], list[int]]) - Defines the first N elements that are retained. | |||
| Only constant value is allowed. | |||
| - **dtype** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed. | |||
| Outputs: | |||
| Tensor. | |||
| If max_length is set, the shape of the output is (lengths.shape, max_length). | |||
| If max_length is not set and the biggest value in lengths is x. Then, the shape of | |||
| the output is (lengths.shape, x). | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> from mindspore.ops import operations as P | |||
| >>> sequence_mask = P.SequenceMask() | |||
| >>> mask = sequence_mask([2, 2, 4], mindspore.int32) | |||
| >>> print(mask) | |||
| [[1, 1, 0, 0], | |||
| [1, 1, 0, 0], | |||
| [1, 1, 1, 1]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize SequenceMask""" | |||
| def __infer__(self, lengths, dtype, max_length=None): | |||
| validator.check_value_type("shape", lengths['value'], [tuple, list], self.name) | |||
| valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, | |||
| mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, | |||
| mstype.float16, mstype.float32, mstype.float64] | |||
| validator.check_subclass("dtype", dtype['value'], valid_types, self.name) | |||
| nptype = mstype.dtype_to_nptype(dtype['value']) | |||
| if max_length is None: | |||
| max_length = np.max(lengths['value']) | |||
| else: | |||
| validator.check_non_negative_int(max_length['value']) | |||
| max_length = max_length['value'] | |||
| row_vector = np.arange(0, max_length) | |||
| col_matrix = np.expand_dims(lengths['value'], -1) | |||
| result = (row_vector < col_matrix).astype(nptype) | |||
| out = { | |||
| 'value': Tensor(result), | |||
| 'shape': result.shape, | |||
| 'dtype': dtype['value'] | |||
| } | |||
| return out | |||
| class OnesLike(PrimitiveWithInfer): | |||
| """ | |||
| Creates a new tensor. The values of all elements are 1. | |||
| @@ -4648,3 +4586,47 @@ class Identity(PrimitiveWithInfer): | |||
| 'dtype': x['dtype'], | |||
| 'value': None} | |||
| return out | |||
| class SequenceMask(PrimitiveWithCheck): | |||
| """ | |||
| Returns a mask tensor representing the first N positions of each cell. | |||
| If lengths has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape | |||
| [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]) | |||
| Inputs: | |||
| - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor must be | |||
| less than `maxlen`. Must be type int32 or int64. | |||
| - **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same | |||
| tyupe as elements in `lengths`. | |||
| Outputs: | |||
| One mask tensor of shape lengths.shape + (maxlen,). | |||
| Supported Platforms: | |||
| ``GPU`` | |||
| Examples: | |||
| >>> x = Tensor(np.array([[1, 3], [2, 0]]) | |||
| >>> sequence_mask = P.SequenceMask() | |||
| >>> output = sequence_mask(x, 3) | |||
| >>> print(output) | |||
| [[[True, False, False], | |||
| [True, True, True]], | |||
| [[True, True, False], | |||
| [False, False, False]]] | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"]) | |||
| def check_shape(self, lengths_shape, maxlen_shape): | |||
| validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name) | |||
| validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name) | |||
| def check_dtype(self, lengths_dtype, maxlen_dtype): | |||
| validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name) | |||
| validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name) | |||
| @@ -0,0 +1,117 @@ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| def sequence_mask(x, maxlen): | |||
| sequence_mask_op = P.SequenceMask() | |||
| return sequence_mask_op(Tensor(x.astype(np.int32)), maxlen) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sequence_mask_1d(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| a = np.array([2, 3, 1]) | |||
| maxlen = 4 | |||
| ms_out = sequence_mask(a, maxlen) | |||
| expected_out = Tensor(np.array([[True, True, False, False], | |||
| [True, True, True, False], | |||
| [True, False, False, False]])) | |||
| np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sequence_mask_2d(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| a = np.array([[0, 1, 3, 2], [1, 4, 4, 2]]) | |||
| maxlen = 6 | |||
| ms_out = sequence_mask(a, maxlen) | |||
| expected_out = Tensor(np.array([[[False, False, False, False, False, False], | |||
| [True, False, False, False, False, False], | |||
| [True, True, True, False, False, False], | |||
| [True, True, False, False, False, False]], | |||
| [[True, False, False, False, False, False], | |||
| [True, True, True, True, False, False], | |||
| [True, True, True, True, False, False], | |||
| [True, True, False, False, False, False]]])) | |||
| np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sequence_mask_3d(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| a = np.array([[[2, 2], [1, 1]], | |||
| [[2, 0], [2, 1]], | |||
| [[0, 0], [0, 0]]]) | |||
| maxlen = 2 | |||
| ms_out = sequence_mask(a, maxlen) | |||
| expected_out = Tensor(np.array([[[[True, True], [True, True]], [[True, False], [True, False]]], | |||
| [[[True, True], [False, False]], [[True, True], [True, False]]], | |||
| [[[False, False], [False, False]], [[False, False], [False, False]]]])) | |||
| np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sequence_mask_maxlen_1(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| a = np.array([[[0, 1], [1, 1]], | |||
| [[1, 0], [1, 1]], | |||
| [[0, 1], [0, 1]]]) | |||
| maxlen = 1 | |||
| ms_out = sequence_mask(a, maxlen) | |||
| expected_out = Tensor(np.array([[[[False], [True]], [[True], [True,]]], | |||
| [[[True], [False]], [[True], [True]]], | |||
| [[[False], [True]], [[False], [True]]]])) | |||
| np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_sequence_mask_dynamic(): | |||
| class SequenceMaskDynamicNet(nn.Cell): | |||
| def __init__(self, maxlen): | |||
| super(SequenceMaskDynamicNet, self).__init__() | |||
| self.maxlen = maxlen | |||
| self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape() | |||
| self.sequence_mask = P.SequenceMask() | |||
| def construct(self, x): | |||
| converted_to_dynamic_shape = self.convert_to_dynamic_shape(x) | |||
| return self.sequence_mask(converted_to_dynamic_shape, self.maxlen) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| sequence_mask_net = SequenceMaskDynamicNet(4) | |||
| a = Tensor(np.array([0, 1, 0, 2, 0, 5])) | |||
| ms_out = sequence_mask_net(a) | |||
| expected_out = Tensor(np.array([[False, False, False, False], | |||
| [True, False, False, False], | |||
| [False, False, False, False], | |||
| [True, True, False, False], | |||
| [False, False, False, False], | |||
| [True, True, True, True]])) | |||
| np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy()) | |||
| a = Tensor(np.array([[4, 3, 0], [0, 1, 3]])) | |||
| ms_out = sequence_mask_net(a) | |||
| expected_out = Tensor(np.array([[[True, True, True, True], | |||
| [True, True, True, False], | |||
| [False, False, False, False]], | |||
| [[False, False, False, False], | |||
| [True, False, False, False], | |||
| [True, True, True, False]]])) | |||
| @@ -42,28 +42,6 @@ def test_expand_dims(): | |||
| assert output.asnumpy().shape == (1, 2, 2) | |||
| def test_sequence_mask(): | |||
| list_ = [2, 2, 4] | |||
| sequence_mask = P.SequenceMask() | |||
| mask1 = sequence_mask(list_, mstype.int32) | |||
| mask2 = sequence_mask(list_, mstype.int32, 5) | |||
| assert mask1.shape == (3, 4) | |||
| assert mask1.dtype == mstype.int32 | |||
| assert mask2.shape == (3, 5) | |||
| assert mask2.dtype == mstype.int32 | |||
| def test_sequence_mask_1(): | |||
| list_ = [[2, 2, 4], [3, 4, 4]] | |||
| sequence_mask = P.SequenceMask() | |||
| mask1 = sequence_mask(list_, mstype.bool_) | |||
| mask2 = sequence_mask(list_, mstype.bool_, 5) | |||
| assert mask1.shape == (2, 3, 4) | |||
| assert mask1.dtype == mstype.bool_ | |||
| assert mask2.shape == (2, 3, 5) | |||
| assert mask2.dtype == mstype.bool_ | |||
| def test_cast(): | |||
| input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) | |||
| input_x = Tensor(input_np) | |||