From: @TFbunny
Reviewed-by: @robingrosman
Tags: v1.2.0-rc1
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@ namespace kernel {
template <typename T, typename S>
class ArgmaxWithValueGpuKernel : public GpuKernel {
 public:
  ArgmaxWithValueGpuKernel() : input_size_(0), output_size_(0), bound_(0), outerSize_(0), innerSize_(0) {}
  ArgmaxWithValueGpuKernel() { ResetResource(); }
  ~ArgmaxWithValueGpuKernel() override = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -75,6 +75,17 @@ class ArgmaxWithValueGpuKernel : public GpuKernel {
    return true;
  }
  void ResetResource() noexcept override {
    input_size_ = 0;
    output_size_ = 0;
    bound_ = 0;
    outerSize_ = 0;
    innerSize_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }
 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);
@@ -1,35 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstdint>
#include "backend/kernel_compiler/gpu/arrays/sequence_mask_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// Keep this registered as TWO even though the output is always bool, in case the framework can
// later support passing an optional dtype; then this can match the TensorFlow behaviour exactly.
MS_REG_GPU_KERNEL_TWO(
  SequenceMask,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
  SequenceMaskGpuKernel, int32_t, bool)
MS_REG_GPU_KERNEL_TWO(
  SequenceMask,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
  SequenceMaskGpuKernel, int64_t, bool)
}  // namespace kernel
}  // namespace mindspore
@@ -1,101 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_
#include "backend/kernel_compiler/gpu/cuda_impl/sequence_mask_impl.cuh"
#include <cuda_runtime.h>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class SequenceMaskGpuKernel : public GpuKernel {
 public:
  SequenceMaskGpuKernel() { ResetResource(); }
  ~SequenceMaskGpuKernel() = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *lengths_device_address = GetDeviceAddress<T>(inputs, 0);
    T *maxlen_device_address = GetDeviceAddress<T>(inputs, 1);
    S *output_device_address = GetDeviceAddress<S>(outputs, 0);
    CalSequenceMask(lengths_device_address, maxlen_device_address, output_device_address, output_size_,
                    reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_count != 2) {
      MS_LOG(EXCEPTION) << input_count << " inputs were provided, but SequenceMaskGpuKernel expects 2.";
    }
    input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    for (const int &e : input_shape_) {
      lengths_size_ *= e;
    }
    std::vector<size_t> inferred_output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
    for (const size_t &e : inferred_output_shape) {
      output_size_ *= e;
    }
    InitSizeLists();
    return true;
  }
  void ResetResource() noexcept override {
    output_size_ = 1;
    lengths_size_ = 1;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }
 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(lengths_size_ * sizeof(T));
    input_size_list_.push_back(sizeof(T));
    output_size_list_.push_back(output_size_);
  }
 private:
  std::vector<size_t> input_shape_;
  size_t lengths_size_;
  size_t output_size_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SEQUENCE_MASK_GPU_KERNEL_H_
@@ -1,50 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cuda_runtime.h>
#include "sequence_mask_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
__global__ void ValidateArgs(int *maxlen, const int lengths_size, const int max_output_size) {
  int maxlen_value = *maxlen;
  if (maxlen_value < 0 || lengths_size * maxlen_value > max_output_size) {
    asm("trap;");
  }
}
template <typename T, typename S>
__global__ void SequenceMask(
    const T *input, T *maxlen, S *output, const size_t output_size) {
  T maxlen_value = *maxlen;
  for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) {
    T mask_comparison_value = gt_id % maxlen_value;
    T input_comparison_index = (gt_id - mask_comparison_value) / maxlen_value;
    S result = mask_comparison_value < input[input_comparison_index];
    output[gt_id] = result;
  }
}
template <typename T, typename S>
void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream) {
  SequenceMask<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(lengths, maxlen, output, output_size);
}
template void CalSequenceMask<int, bool>(const int *lengths, int *maxlen, bool *output, const size_t output_size,
                                         cudaStream_t cuda_stream);
template void CalSequenceMask<int64_t, bool>(const int64_t *lengths, int64_t *maxlen, bool *output,
                                             const size_t output_size, cudaStream_t cuda_stream);
@@ -1,25 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_
#include <cuda_runtime.h>
template <typename T, typename S>
void CalSequenceMask(const T *lengths, T *maxlen, S *output, const size_t output_size, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SEQUENCE_MASK_CUH_
@@ -304,6 +304,10 @@ AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &p
                                const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tuple or list or dict.
@@ -1068,5 +1068,64 @@ AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &pr
  return std::make_shared<AbstractTensor>(range_start_type, shape);
}
AbstractBasePtr InferImplArgMaxWithValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  // check keep_dims
  ValuePtr keep_dims = primitive->GetAttr("keep_dims");
  MS_EXCEPTION_IF_NULL(keep_dims);
  if (!keep_dims->isa<BoolImm>()) {
    MS_LOG(EXCEPTION) << "keep_dims should be Bool.";
  }
  bool keep_dims_value = GetValue<bool>(keep_dims);
  // check axis
  ValuePtr axis = primitive->GetAttr("axis");
  MS_EXCEPTION_IF_NULL(axis);
  if (!axis->isa<Int32Imm>() && !axis->isa<Int64Imm>()) {
    MS_LOG(EXCEPTION) << "axis should be Int.";
  }
  // check the axis range and convert a negative axis to its positive equivalent
  auto check_axis = [](int64_t &axis, const size_t dim) -> void {
    int64_t dim_ = static_cast<int64_t>(dim);
    if (axis < -dim_ || axis >= dim_) {
      MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis << ".";
    }
    if (axis >= -dim_ && axis < 0) {
      axis += dim_;
    }
    return;
  };
  // main shape-calculation helper
  auto cal_shape = [axis, keep_dims_value, check_axis](ShapeVector &shape, const ShapeVector &x_shape) -> void {
    shape.insert(shape.end(), x_shape.begin(), x_shape.end());
    int64_t axis_value = GetValue<int64_t>(axis);
    check_axis(axis_value, x_shape.size());
    if (keep_dims_value) {
      shape[axis_value] = 1;
    } else {
      shape.erase(std::begin(shape) + axis_value);
    }
  };
  ShapeVector shape = {};
  ShapeVector min_shape = {};
  ShapeVector max_shape = {};
  ShapeVector x_shape = x->shape()->shape();
  ShapeVector x_min_shape = x->shape()->min_shape();
  ShapeVector x_max_shape = x->shape()->max_shape();
  (void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
  cal_shape(shape, x_shape);
  cal_shape(min_shape, x_min_shape);
  cal_shape(max_shape, x_max_shape);
  TypePtr idx_type = kInt32;
  auto index = std::make_shared<AbstractTensor>(idx_type, std::make_shared<Shape>(shape, min_shape, max_shape));
  auto value = std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
  AbstractBasePtrList result = {index, value};
  return std::make_shared<AbstractTuple>(result);
}
}  // namespace abstract
}  // namespace mindspore
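The shape rule implemented by InferImplArgMaxWithValue reduces to normalising a possibly negative axis and then either keeping that dimension as 1 (keep_dims) or dropping it; both outputs (index and value) share the result, and the index tensor is int32. A small Python sketch of the same rule (hypothetical helper, for illustration only):

def argmax_with_value_out_shape(x_shape, axis, keep_dims):
    # Mirrors the cal_shape lambda above: normalise the axis, then keep that dim as 1 or drop it.
    dim = len(x_shape)
    if not -dim <= axis < dim:
        raise ValueError(f"axis should be in [{-dim}, {dim}), but got axis = {axis}.")
    if axis < 0:
        axis += dim
    out = list(x_shape)
    if keep_dims:
        out[axis] = 1
    else:
        out.pop(axis)
    return tuple(out)

# argmax_with_value_out_shape((2, 3, 4), axis=-1, keep_dims=False) -> (2, 3)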
@@ -1,5 +1,5 @@
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -454,5 +454,37 @@ AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitiveP
  }
  return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
}
AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  ShapeVector x_shape = x->shape()->shape();
  ShapeVector x_shape_min = x->shape()->min_shape().empty() ? x_shape : x->shape()->min_shape();
  ShapeVector x_shape_max = x->shape()->max_shape().empty() ? x_shape : x->shape()->max_shape();
  auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(y);
  MS_EXCEPTION_IF_NULL(y->shape());
  ShapeVector y_shape = y->shape()->shape();
  ShapeVector y_shape_min = y->shape()->min_shape().empty() ? y_shape : y->shape()->min_shape();
  ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape();
  auto out_shape = BroadcastShape(x_shape, y_shape);
  if (out_shape.empty()) {
    MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
                      << args_spec_list[1]->ToString();
  }
  auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min);
  auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max);
  auto output_type = std::make_shared<Bool>();
  auto ret =
    std::make_shared<AbstractTensor>(output_type, std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max));
  return ret;
}
}  // namespace abstract
}  // namespace mindspore
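InferImplLess infers a bool tensor whose shape is the broadcast of the two input shapes, with the min/max shapes broadcast the same way for dynamic-shape graphs. A sketch of right-aligned broadcasting as BroadcastShape is assumed to behave for static dimensions (illustrative Python helper, not the actual C++ implementation):

from itertools import zip_longest

def broadcast_shape(x_shape, y_shape):
    # Align shapes from the right; each pair of dims must match or one of them must be 1.
    out = []
    for a, b in zip_longest(reversed(x_shape), reversed(y_shape), fillvalue=1):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"cannot broadcast {x_shape} with {y_shape}")
        out.append(max(a, b))
    return tuple(reversed(out))

# Less((2, 1, 4), (3, 4)) is inferred as a bool tensor of shape (2, 3, 4):
assert broadcast_shape((2, 1, 4), (3, 4)) == (2, 3, 4)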
@@ -74,6 +74,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimAddN, {InferImplAddN, true}},
    {prim::kPrimMatMul, {InferImplMatMul, true}},
    {prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
    {prim::kPrimLess, {InferImplLess, true}},
    // Array
    {prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
    {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
@@ -108,6 +109,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
    {prim::kPrimConcat, {InferImplConcat, true}},
    {prim::kPrimRange, {InferImplRange, true}},
    {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}},
    // Structure
    {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
    {prim::kPrimMakeList, {InferImplMakeList, true}},
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,6 @@ from mindspore._checkparam import Rel
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from .. import operations as P
from ..operations import _inner_ops as inner
@constexpr
@@ -105,7 +104,15 @@ def repeat_elements(x, rep, axis=0):
    return x_rep
def sequence_mask(lengths, maxlen):
@constexpr
def _check_sequence_mask_input_len(input_shape):
    if not input_shape:
        raise ValueError(f"sequence_mask: lengths must have rank > 0, "
                         f"but the shape of lengths is {input_shape}.")
def sequence_mask(lengths, maxlen=None):
    """
    Returns a mask tensor representing the first N positions of each cell.
@@ -135,4 +142,29 @@ def sequence_mask(lengths, maxlen):
        [[True, True, False],
         [False, False, False]]]
    """
    return inner.SequenceMask()(lengths, maxlen)
    argmax_op = P.ArgMaxWithValue()
    reshape_op = P.Reshape()
    range_op = P.Range()
    expand_op = P.ExpandDims()
    cast_op = P.Cast()
    shape_op = P.Shape()
    to_tensor_op = P.ScalarToArray()
    const_utils.check_type_valid(F.dtype(lengths), [mstype.int64, mstype.int32], 'lengths')
    _check_sequence_mask_input_len(shape_op(lengths))
    if maxlen is None:
        flatten_data = reshape_op(lengths, (-1,))
        flatten_data = cast_op(flatten_data, mstype.float32)
        _, value = argmax_op(flatten_data)
        maxlen = cast_op(value, mstype.int32)
    else:
        maxlen = _check_positive_int(maxlen, "maxlen", "sequence_mask")
        maxlen = to_tensor_op(maxlen)
    range_vector = range_op(to_tensor_op(0), maxlen, to_tensor_op(1))
    mask = expand_op(lengths, -1)
    result = range_vector < mask
    return result
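The composite builds the mask with a single broadcast comparison: a [0, maxlen) range vector against lengths expanded with a trailing axis. A NumPy equivalent of that trick (a sketch assuming maxlen defaults to max(lengths), matching the ArgMaxWithValue path above; not the MindSpore API itself):

import numpy as np

def sequence_mask_numpy(lengths, maxlen=None):
    # Same broadcast trick as the composite above: arange(maxlen) < lengths[..., None].
    lengths = np.asarray(lengths)
    if maxlen is None:
        maxlen = int(lengths.max())           # the composite derives this via ArgMaxWithValue
    range_vector = np.arange(maxlen)          # shape (maxlen,)
    return range_vector < lengths[..., None]  # broadcasts to lengths.shape + (maxlen,)

# sequence_mask_numpy([2, 3, 1]) ->
# [[ True  True False]
#  [ True  True  True]
#  [ True False False]]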
@@ -1,3 +1,17 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
@@ -16,7 +30,6 @@ def sequence_mask(x, maxlen):
def test_sequence_mask_1d():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    a = np.array([2, 3, 1])
    maxlen = 4
    ms_out = sequence_mask(a, maxlen)
    expected_out = Tensor(np.array([[True, True, False, False],
@@ -30,7 +43,6 @@ def test_sequence_mask_1d():
def test_sequence_mask_2d():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    a = np.array([[0, 1, 3, 2], [1, 4, 4, 2]])
    maxlen = 6
    ms_out = sequence_mask(a, maxlen)
    expected_out = Tensor(np.array([[[False, False, False, False, False, False],
@@ -51,7 +63,6 @@ def test_sequence_mask_3d():
    a = np.array([[[2, 2], [1, 1]],
                  [[2, 0], [2, 1]],
                  [[0, 0], [0, 0]]])
    maxlen = 2
    ms_out = sequence_mask(a, maxlen)
    expected_out = Tensor(np.array([[[[True, True], [True, True]], [[True, False], [True, False]]],
@@ -68,7 +79,6 @@ def test_sequence_mask_maxlen_1():
    a = np.array([[[0, 1], [1, 1]],
                  [[1, 0], [1, 1]],
                  [[0, 1], [0, 1]]])
    maxlen = 1
    ms_out = sequence_mask(a, maxlen)
    expected_out = Tensor(np.array([[[[False], [True]], [[True], [True,]]],
@@ -81,9 +91,9 @@ def test_sequence_mask_maxlen_1():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sequence_mask_dynamic():
    class SequenceMaskDynamicNet(nn.Cell):
    class SequenceMaskDynamicNet1(nn.Cell):
        def __init__(self, maxlen):
            super(SequenceMaskDynamicNet, self).__init__()
            super(SequenceMaskDynamicNet1, self).__init__()
            self.maxlen = maxlen
            self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
@@ -91,9 +101,18 @@ def test_sequence_mask_dynamic():
            converted_to_dynamic_shape = self.convert_to_dynamic_shape(x)
            return C.sequence_mask(converted_to_dynamic_shape, self.maxlen)
    class SequenceMaskDynamicNet2(nn.Cell):
        def __init__(self):
            super(SequenceMaskDynamicNet2, self).__init__()
            self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        def construct(self, x):
            converted_to_dynamic_shape = self.convert_to_dynamic_shape(x)
            return C.sequence_mask(converted_to_dynamic_shape)
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    sequence_mask_net = SequenceMaskDynamicNet(4)
    sequence_mask_net = SequenceMaskDynamicNet1(4)
    a = Tensor(np.array([0, 1, 0, 2, 0, 5]))
    ms_out = sequence_mask_net(a)
@@ -113,3 +132,39 @@ def test_sequence_mask_dynamic():
                                    [[False, False, False, False],
                                     [True, False, False, False],
                                     [True, True, True, False]]]))
    np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())
    net_without_maxlen = SequenceMaskDynamicNet2()
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    a = np.array([2, 3, 1])
    ms_out = net_without_maxlen(Tensor(a))
    expected_out = Tensor(np.array([[True, True, False],
                                    [True, True, True],
                                    [True, False, False]]))
    np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())
def sequence_mask_optional(x):
    return C.sequence_mask(Tensor(x.astype(np.int32)))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sequence_mask_optional_maxlen():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    a = np.array([2, 3, 1])
    ms_out = sequence_mask_optional(a)
    expected_out = Tensor(np.array([[True, True, False],
                                    [True, True, True],
                                    [True, False, False]]))
    np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    a = np.array([2, 3, 1])
    ms_out = sequence_mask_optional(a)
    expected_out = Tensor(np.array([[True, True, False],
                                    [True, True, True],
                                    [True, False, False]]))
    np.testing.assert_array_equal(expected_out.asnumpy(), ms_out.asnumpy())