From 7f769568a64d3f73635d46b9bdf29d10f501d5ef Mon Sep 17 00:00:00 2001 From: TFbunny Date: Thu, 5 Nov 2020 12:04:03 -0500 Subject: [PATCH] Register SparseGatherV2 and add dynamic shape support --- .../gpu/arrays/gatherv2_gpu_kernel.cc | 1 - tests/st/ops/gpu/test_sparse_gather_v2_op.py | 939 ++++++++++++++++++ 2 files changed, 939 insertions(+), 1 deletion(-) create mode 100644 tests/st/ops/gpu/test_sparse_gather_v2_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc index a54627b41f..c64696de6e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc @@ -26,7 +26,6 @@ MS_REG_GPU_KERNEL_TWO( GatherV2, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), GatherV2GpuFwdKernel, half, int) - MS_REG_GPU_KERNEL_TWO( SparseGatherV2, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), diff --git a/tests/st/ops/gpu/test_sparse_gather_v2_op.py b/tests/st/ops/gpu/test_sparse_gather_v2_op.py new file mode 100644 index 0000000000..2b505c4366 --- /dev/null +++ b/tests/st/ops/gpu/test_sparse_gather_v2_op.py @@ -0,0 +1,939 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class SparseGatherNet(nn.Cell): + def __init__(self): + super(SparseGatherNet, self).__init__() + self.gather = P.SparseGatherV2() + + def construct(self, x, indices): + return self.gather(x, indices, 1) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather0(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) + indices = Tensor(np.ones((2, 2, 4, 5), dtype='i4')) + expect = np.array([[[[[[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]]], + + [[[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]]]], + + [[[[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]]], + + [[[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]]]]], + + [[[[[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]]], + + [[[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]]]], + + [[[[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]]], + + [[[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]]]]]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = SparseGatherNet() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +class SparseGatherNet1(nn.Cell): + def __init__(self): + super(SparseGatherNet1, self).__init__() + self.gather = P.SparseGatherV2() + + def construct(self, x, indices): + return self.gather(x, indices, -1) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather1(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([1, 3, 4], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = SparseGatherNet1() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +class SparseGatherNet2(nn.Cell): + def __init__(self): + super(SparseGatherNet2, self).__init__() + self.gather = P.SparseGatherV2() + + def construct(self, x, indices): + return self.gather(x, indices, 0) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather2(): + x = Tensor(np.array([[4., 5., 4., 1., 5.,], + [4., 9., 5., 6., 4.,], + [9., 8., 4., 3., 6.,], + [0., 4., 2., 2., 8.,], + [1., 8., 6., 2., 8.,], + [8., 1., 9., 7., 3.,], + [7., 9., 2., 5., 7.,], + [9., 8., 6., 8., 5.,], + [3., 7., 2., 7., 4.,], + [4., 2., 8., 2., 9.,]] + ).astype(np.float32)) + + indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32)) + expect = np.array([[[0., 0., 0., 0., 0.], + [4., 9., 5., 6., 4.], + [0., 0., 0., 0., 0.]]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = SparseGatherNet2() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error)