From: @TFbunny Reviewed-by: @tom__chen,@robingrosman Signed-off-by: @robingrosman tags/v1.2.0-rc1
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -18,6 +18,10 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherNd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), | |||
| GatherNdGpuFwdKernel, double, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherNd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| @@ -38,6 +42,10 @@ MS_REG_GPU_KERNEL_TWO( | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherNd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), | |||
| GatherNdGpuFwdKernel, bool, int) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherNd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), | |||
| GatherNdGpuFwdKernel, double, int64_t) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| GatherNd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERND_GPU_KERNEL_H | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERND_GPU_KERNEL_H | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERND_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERND_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| @@ -171,4 +171,4 @@ class GatherNdGpuFwdKernel : public GpuKernel { | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERND_GPU_KERNEL_H | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_GATHERND_GPU_KERNEL_H_ | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -54,6 +54,9 @@ void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const | |||
| return; | |||
| } | |||
| template void GatherNd<double, int>(double *input, int *indices, double *output, const size_t &output_dim0, | |||
| const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, | |||
| int *batch_strides, cudaStream_t stream); | |||
| template void GatherNd<float, int>(float *input, int *indices, float *output, const size_t &output_dim0, | |||
| const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, | |||
| int *batch_strides, cudaStream_t stream); | |||
| @@ -73,6 +76,9 @@ template void GatherNd<unsigned char, int>(unsigned char *input, int *indices, u | |||
| template void GatherNd<bool, int>(bool *input, int *indices, bool *output, const size_t &output_dim0, | |||
| const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, | |||
| int *batch_strides, cudaStream_t stream); | |||
| template void GatherNd<double, int64_t>(double *input, int64_t *indices, double *output, const size_t &output_dim0, | |||
| const size_t &output_dim1, const size_t &indices_dim1, int64_t *batch_indices, | |||
| int64_t *batch_strides, cudaStream_t stream); | |||
| template void GatherNd<float, int64_t>(float *input, int64_t *indices, float *output, const size_t &output_dim0, | |||
| const size_t &output_dim1, const size_t &indices_dim1, int64_t *batch_indices, | |||
| int64_t *batch_strides, cudaStream_t stream); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERND_GPU_CU_H | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERND_GPU_CU_H | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERND_CUH_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERND_CUH_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| @@ -23,4 +23,4 @@ template <typename T, typename S> | |||
| void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1, | |||
| const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERND_GPU_CU_H | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_GATHERND_CUH_ | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -40,6 +40,12 @@ def gathernd0(nptype): | |||
| assert np.array_equal(output.asnumpy(), expect) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gathernd0_float64(): | |||
| gathernd0(np.float64) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @@ -110,6 +116,12 @@ def gathernd1(nptype): | |||
| assert np.array_equal(output.asnumpy(), expect) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gathernd1_float64(): | |||
| gathernd1(np.float64) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| @@ -163,6 +175,12 @@ def gathernd2(nptype): | |||
| assert np.array_equal(output.asnumpy(), expect) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gathernd2_float64(): | |||
| gathernd2(np.float64) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||