diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.cc new file mode 100644 index 0000000000..54218a5275 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + GatherD, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + GatherGpuFwdKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + GatherD, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + GatherGpuFwdKernel, float, int64_t) +MS_REG_GPU_KERNEL_TWO( + GatherD, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + GatherGpuFwdKernel, half, int) +MS_REG_GPU_KERNEL_TWO( + GatherD, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + GatherGpuFwdKernel, half, int64_t) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h new file mode 100644 index 0000000000..870f33eb55 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h @@ -0,0 +1,123 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_GATHER_GPU_KERNEL_H +#define MINDSPORE_GATHER_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh" + +namespace mindspore { +namespace kernel { +template +class GatherGpuFwdKernel : public GpuKernel { + public: + GatherGpuFwdKernel() : axis_(0), handle_(nullptr) {} + ~GatherGpuFwdKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *input_addr = GetDeviceAddress(inputs, 0); + S *index_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + + Gather(input_addr, index_addr, output_addr, dims_[0], dims_[1], dims_[2], + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuFwdKernel needs 2."; + } + input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + index_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + axis_ = GetAttr(kernel_node, "dim"); + if (axis_ < 0) { + axis_ = axis_ + SizeToInt(input_shapes_.size()); + } + Reshape(); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + void InitSizeLists() override { + size_t size = GetSize(input_shapes_, true); + input_size_list_.push_back(size); + + size = GetSize(index_shapes_, false); + input_size_list_.push_back(size); + + size = GetSize(output_shapes_, true); + output_size_list_.push_back(size); + } + + private: + void Reshape() { + size_t dim_before_axis = 1; + for (size_t i = 0; i < IntToSize(axis_); i++) { + dim_before_axis *= output_shapes_[i]; + } + size_t dim_of_index = output_shapes_[IntToSize(axis_)]; + size_t dim_after_index = 1; + for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) { + dim_after_index *= output_shapes_[i]; + } + + dims_[0] = dim_before_axis; + dims_[1] = dim_of_index; + dims_[2] = dim_after_index; + return; + } + size_t GetSize(const std::vector &shape, const bool flag = true) const { + if (shape.size() == 0) { + return 0; + } + size_t result = flag ? sizeof(T) : sizeof(S); + for (size_t i = 0; i < shape.size(); i++) { + result *= shape[i]; + } + return result; + } + + std::vector input_shapes_; + std::vector index_shapes_; + std::vector output_shapes_; + + size_t dims_[3] = {}; + int axis_; + cudnnHandle_t handle_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_GATHER_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu new file mode 100755 index 0000000000..fc908c4c5c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void GatherKernel(const T *input, const S *index, T *output, const size_t output_dim0, + const size_t output_dim1, const size_t output_dim2) { + size_t num = output_dim0 * output_dim1 * output_dim2; + size_t i, k; + for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num; + id += blockDim.x * gridDim.x) { + i = id / (output_dim1 * output_dim2) % output_dim0; + k = id % output_dim2; + + size_t j_read = static_cast(index[id]); + size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k; + output[id] = input[read_id]; + } + return; +} +template +void Gather(const T *input, const S *index, T *output, const size_t output_dim0, const size_t output_dim1, + const size_t output_dim2, cudaStream_t stream) { + size_t size = output_dim0 * output_dim1 * output_dim2; + GatherKernel<<>>(input, index, output, output_dim0, output_dim1, + output_dim2); + return; +} + +template void Gather(const float *input, const int *index, float *output, const size_t output_dim0, + const size_t output_dim1, const size_t output_dim2, cudaStream_t stream); +template void Gather(const float *input, const int64_t *index, float *output, const size_t output_dim0, + const size_t output_dim1, const size_t output_dim2, cudaStream_t stream); +template void Gather(const half *input, const int *index, half *output, const size_t output_dim0, + const size_t output_dim1, const size_t output_dim2, cudaStream_t stream); +template void Gather(const half *input, const int64_t *index, half *output, const size_t output_dim0, + const size_t output_dim1, const size_t output_dim2, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cuh new file mode 100644 index 0000000000..6f3aa87a82 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cuh @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_GATHER_GPU_CU_H +#define MINDSPORE_GATHER_GPU_CU_H +template +void Gather(const T *input, const S *index, T *output, const size_t output_dim0, const size_t output_dim1, + const size_t output_dim2, cudaStream_t stream); + +#endif diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc index a64741034a..c0fb214752 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc @@ -37,6 +37,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(prim::kPrimReduceSum->name(), {1}); Register(prim::kPrimReduceMean->name(), {1}); Register(prim::kPrimGatherV2->name(), {2}); + Register(prim::kPrimGatherD->name(), {1}); Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5}); Register(prim::kPrimEmbeddingLookupCommGrad->name(), {1}); Register(prim::kPrimSubscalar->name(), {1}); diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc index a652bd3fd0..6eea6833dd 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc @@ -55,6 +55,12 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An continue; } } + if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) { + auto ms_context = MsContext::GetInstance(); + if (ms_context->get_param(MS_CTX_DEVICE_TARGET) != kGPUDevice) { + continue; + } + } if (AnfAlgo::IsDynamicShape(cnode)) { MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope(); continue; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index dec339155f..b4152e8a39 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -590,6 +590,13 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector *te if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { reg_exist = false; } + if (op_run_info->op_name == prim::kPrimGatherD->name()) { + auto ms_context = MsContext::GetInstance(); + // Gather op needs converting const input to attr on GPU device + if (ms_context->get_param(MS_CTX_DEVICE_TARGET) != kGPUDevice) { + reg_exist = false; + } + } op_prim->BeginRecordAddAttr(); size_t input_num = op_run_info->op_inputs.size(); diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index fcaa398027..db8fa1ab6f 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -84,6 +84,7 @@ inline const PrimitivePtr kPrimConcat = std::make_shared("Concat"); inline const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); inline const PrimitivePtr kPrimTranspose = std::make_shared("Transpose"); inline const PrimitivePtr kPrimGatherV2 = std::make_shared("GatherV2"); +inline const PrimitivePtr kPrimGatherD = std::make_shared("GatherD"); inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared("SparseGatherV2"); inline const PrimitivePtr kPrimShape = std::make_shared("Shape"); inline const PrimitivePtr kPrimDynamicShape = std::make_shared("DynamicShape"); diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index de98f8f785..3f5510f74c 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -4063,6 +4063,7 @@ class GatherD(PrimitiveWithInfer): @prim_attr_register def __init__(self): """Initialize GatherD""" + self.init_prim_io_names(inputs=['x', 'dim', 'index'], outputs=['output']) def __infer__(self, x, dim, index): validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) diff --git a/tests/st/ops/gpu/test_gatherV2_op.py b/tests/st/ops/gpu/test_gatherV2_op.py new file mode 100644 index 0000000000..dae1cfb62c --- /dev/null +++ b/tests/st/ops/gpu/test_gatherV2_op.py @@ -0,0 +1,939 @@ +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class GatherNet(nn.Cell): + def __init__(self): + super(GatherNet, self).__init__() + self.gather = P.GatherV2() + + def construct(self, x, indices): + return self.gather(x, indices, 1) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather0(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) + indices = Tensor(np.ones((2, 2, 4, 5), dtype='i4')) + expect = np.array([[[[[[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]]], + + [[[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]]]], + + [[[[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]]], + + [[[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]], + + [[[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]], + + [[20., 21., 22., 23., 24.], + [25., 26., 27., 28., 29.], + [30., 31., 32., 33., 34.], + [35., 36., 37., 38., 39.]]]]]], + + [[[[[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]]], + + [[[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]]]], + + [[[[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]]], + + [[[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]], + + [[[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]], + + [[80., 81., 82., 83., 84.], + [85., 86., 87., 88., 89.], + [90., 91., 92., 93., 94.], + [95., 96., 97., 98., 99.]]]]]]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNet() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +class GatherNet1(nn.Cell): + def __init__(self): + super(GatherNet1, self).__init__() + self.gather = P.GatherV2() + + def construct(self, x, indices): + return self.gather(x, indices, -1) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather1(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([1, 3, 4], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNet1() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +class GatherNet2(nn.Cell): + def __init__(self): + super(GatherNet2, self).__init__() + self.gather = P.GatherV2() + + def construct(self, x, indices): + return self.gather(x, indices, 0) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather2(): + x = Tensor(np.array([[4., 5., 4., 1., 5.,], + [4., 9., 5., 6., 4.,], + [9., 8., 4., 3., 6.,], + [0., 4., 2., 2., 8.,], + [1., 8., 6., 2., 8.,], + [8., 1., 9., 7., 3.,], + [7., 9., 2., 5., 7.,], + [9., 8., 6., 8., 5.,], + [3., 7., 2., 7., 4.,], + [4., 2., 8., 2., 9.,]] + ).astype(np.float32)) + + indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32)) + expect = np.array([[[0., 0., 0., 0., 0.], + [4., 9., 5., 6., 4.], + [0., 0., 0., 0., 0.]]]) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNet2() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) diff --git a/tests/st/ops/gpu/test_gather_op.py b/tests/st/ops/gpu/test_gather_op.py index 5c882a6567..69a3cec1c8 100644 --- a/tests/st/ops/gpu/test_gather_op.py +++ b/tests/st/ops/gpu/test_gather_op.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,922 +18,120 @@ import pytest import mindspore.context as context import mindspore.nn as nn +import mindspore as ms from mindspore import Tensor from mindspore.ops import operations as P class GatherNet(nn.Cell): - def __init__(self): + def __init__(self, dim=0): super(GatherNet, self).__init__() - self.gather = P.GatherV2() - - def construct(self, x, indices): - return self.gather(x, indices, 1) + self.gather = P.GatherD() + self.dim = dim + def construct(self, x, index): + return self.gather(x, self.dim, index) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_gather0(): - x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) - indices = Tensor(np.ones((2, 2, 4, 5), dtype='i4')) - expect = np.array([[[[[[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]]], - - [[[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]]]], - - [[[[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]]], - - [[[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]], - - [[[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]], - - [[20., 21., 22., 23., 24.], - [25., 26., 27., 28., 29.], - [30., 31., 32., 33., 34.], - [35., 36., 37., 38., 39.]]]]]], - - [[[[[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]]], - - [[[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]]]], - - [[[[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]]], - - [[[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]], - - [[[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]], - - [[80., 81., 82., 83., 84.], - [85., 86., 87., 88., 89.], - [90., 91., 92., 93., 94.], - [95., 96., 97., 98., 99.]]]]]]]) - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - gather = GatherNet() - output = gather(x, indices) - error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 +def test_gather_pynative_fp32_int32(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + error = 1e-3 + x = Tensor(np.array([[1.303, 2.333], [3.232, 4.235]]), ms.float32) + dim = 1 + index = Tensor(np.array([[0, 0], [1, 0]]), ms.int32) + expect = np.array([[1.303, 1.303], [4.235, 3.232]], np.float32) + output = P.GatherD()(x, dim, index) diff = output.asnumpy() - expect assert np.all(diff < error) - assert np.all(-diff < error) - - -class GatherNet1(nn.Cell): - def __init__(self): - super(GatherNet1, self).__init__() - self.gather = P.GatherV2() - - def construct(self, x, indices): - return self.gather(x, indices, -1) - @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_gather1(): - x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) - indices = Tensor(np.array([1, 3, 4], dtype='i4')) - expect = np.array([[[[1., 3., 4.], - [6., 8., 9.], - [11., 13., 14.], - [16., 18., 19.]], - - [[21., 23., 24.], - [26., 28., 29.], - [31., 33., 34.], - [36., 38., 39.]], - - [[41., 43., 44.], - [46., 48., 49.], - [51., 53., 54.], - [56., 58., 59.]]], - - [[[61., 63., 64.], - [66., 68., 69.], - [71., 73., 74.], - [76., 78., 79.]], - - [[81., 83., 84.], - [86., 88., 89.], - [91., 93., 94.], - [96., 98., 99.]], - - [[101., 103., 104.], - [106., 108., 109.], - [111., 113., 114.], - [116., 118., 119.]]]]) - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - gather = GatherNet1() - output = gather(x, indices) - error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 +def test_gather_pynative_fp32_int64(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + error = 1e-3 + x = Tensor(np.array([[1.303, 2.333], [3.232, 4.235]]), ms.float32) + dim = 1 + index = Tensor(np.array([[0, 0], [1, 0]]), ms.int64) + expect = np.array([[1.303, 1.303], [4.235, 3.232]], np.float32) + output = P.GatherD()(x, dim, index) diff = output.asnumpy() - expect assert np.all(diff < error) - assert np.all(-diff < error) - - -class GatherNet2(nn.Cell): - def __init__(self): - super(GatherNet2, self).__init__() - self.gather = P.GatherV2() - - def construct(self, x, indices): - return self.gather(x, indices, 0) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather_pynative_fp16_int32(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + error = 1e-3 + x = Tensor(np.array([[1.303, 2.333], [3.232, 4.235]]), ms.float16) + dim = 1 + index = Tensor(np.array([[0, 0], [1, 0]]), ms.int32) + expect = np.array([[1.303, 1.303], [4.235, 3.232]], np.float16) + output = P.GatherD()(x, dim, index) + diff = output.asnumpy() - expect + assert np.all(diff < error) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_gather2(): - x = Tensor(np.array([[4., 5., 4., 1., 5.,], - [4., 9., 5., 6., 4.,], - [9., 8., 4., 3., 6.,], - [0., 4., 2., 2., 8.,], - [1., 8., 6., 2., 8.,], - [8., 1., 9., 7., 3.,], - [7., 9., 2., 5., 7.,], - [9., 8., 6., 8., 5.,], - [3., 7., 2., 7., 4.,], - [4., 2., 8., 2., 9.,]] - ).astype(np.float32)) +def test_gather_pynative_fp16_int64(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + error = 1e-3 + x = Tensor(np.array([[1.303, 2.333], [3.232, 4.235]]), ms.float16) + dim = 1 + index = Tensor(np.array([[0, 0], [1, 0]]), ms.int64) + expect = np.array([[1.303, 1.303], [4.235, 3.232]], np.float16) + output = P.GatherD()(x, dim, index) + diff = output.asnumpy() - expect + assert np.all(diff < error) + +def test_gather_graph_fp32_int32(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + error = 1e-3 + x = Tensor(np.array([[1.303, 2.333], [3.232, 4.235]]), ms.float32) + dim = 1 + index = Tensor(np.array([[0, 0], [1, 0]]), ms.int32) + expect = np.array([[1.303, 1.303], [4.235, 3.232]], np.float32) + gather = GatherNet(dim) + output = gather(x, index) + diff = output.asnumpy() - expect + assert np.all(diff < error) + +def test_gather_graph_fp32_int64(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + error = 1e-3 + x = Tensor(np.array([[1.303, 2.333], [3.232, 4.235]]), ms.float32) + dim = 1 + index = Tensor(np.array([[0, 0], [1, 0]]), ms.int64) + expect = np.array([[1.303, 1.303], [4.235, 3.232]], np.float32) + gather = GatherNet(dim) + output = gather(x, index) + diff = output.asnumpy() - expect + assert np.all(diff < error) - indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32)) - expect = np.array([[[0., 0., 0., 0., 0.], - [4., 9., 5., 6., 4.], - [0., 0., 0., 0., 0.]]]) +def test_gather_graph_fp16_int32(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + error = 1e-3 + x = Tensor(np.array([[1.303, 2.333], [3.232, 4.235]]), ms.float16) + dim = 1 + index = Tensor(np.array([[0, 0], [1, 0]]), ms.int32) + expect = np.array([[1.303, 1.303], [4.235, 3.232]], np.float16) + gather = GatherNet(dim) + output = gather(x, index) + diff = output.asnumpy() - expect + assert np.all(diff < error) +def test_gather_graph_fp16_int64(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - gather = GatherNet2() - output = gather(x, indices) - error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + error = 1e-3 + x = Tensor(np.array([[1.303, 2.333], [3.232, 4.235]]), ms.float16) + dim = 1 + index = Tensor(np.array([[0, 0], [1, 0]]), ms.int64) + expect = np.array([[1.303, 1.303], [4.235, 3.232]], np.float16) + gather = GatherNet(dim) + output = gather(x, index) diff = output.asnumpy() - expect assert np.all(diff < error) - assert np.all(-diff < error)