From 21c96b3c31660cf030020a8dabc73d46946f0dc7 Mon Sep 17 00:00:00 2001
From: zhouyuanshen
Date: Fri, 30 Oct 2020 17:19:12 +0800
Subject: [PATCH] add support for GatherDGrad op on GPU

---
 .../gpu/arrays/gather_grad_gpu_kernel.cc      |  38 ++++
 .../gpu/arrays/gather_grad_gpu_kernel.h       | 124 +++++++++++++
 .../gpu/cuda_impl/gather_grad.cu              |  59 +++++++
 .../gpu/cuda_impl/gather_grad.cuh             |  23 +++
 mindspore/ops/_grad/grad_array_ops.py         |  11 +
 mindspore/ops/operations/_grad_ops.py         |  17 ++
 tests/st/ops/gpu/test_gather_grad_op.py       | 163 ++++++++++++++++++
 7 files changed, 435 insertions(+)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h
 create mode 100755 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh
 create mode 100644 tests/st/ops/gpu/test_gather_grad_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..8991c95379
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.cc
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(
+  GatherDGrad,
+  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  GatherGradGpuKernel, int, float)
+MS_REG_GPU_KERNEL_TWO(
+  GatherDGrad,
+  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  GatherGradGpuKernel, int64_t, float)
+MS_REG_GPU_KERNEL_TWO(
+  GatherDGrad,
+  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  GatherGradGpuKernel, int, half)
+MS_REG_GPU_KERNEL_TWO(
+  GatherDGrad,
+  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  GatherGradGpuKernel, int64_t, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h
new file mode 100644
index 0000000000..221ca8459d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h
@@ -0,0 +1,124 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_GATHER_GRAD_GPU_KERNEL_H
+#define MINDSPORE_GATHER_GRAD_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class GatherGradGpuKernel : public GpuKernel {
+ public:
+  GatherGradGpuKernel() : axis_(0), handle_(nullptr) {}
+  ~GatherGradGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *index_addr = GetDeviceAddress<T>(inputs, 0);
+    S *grad_addr = GetDeviceAddress<S>(inputs, 1);
+    S *output_addr = GetDeviceAddress<S>(outputs, 0);
+
+    GatherGrad(index_addr, grad_addr, output_addr, dims_[0], dims_[1], dims_[2],
+               reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGradGpuKernel needs 2.";
+    }
+    index_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    grad_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+
+    axis_ = GetAttr<int>(kernel_node, "dim");
+    if (axis_ < 0) {
+      axis_ = axis_ + SizeToInt(index_shapes_.size());
+    }
+
+    Reshape();
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
+  void InitSizeLists() override {
+    size_t size = GetSize(index_shapes_, true);
+    input_size_list_.push_back(size);
+
+    size = GetSize(grad_shapes_, false);
+    input_size_list_.push_back(size);
+
+    size = GetSize(output_shapes_, false);
+    output_size_list_.push_back(size);
+  }
+
+ private:
+  void Reshape() {
+    size_t dim_before_axis = 1;
+    for (size_t i = 0; i < IntToSize(axis_); i++) {
+      dim_before_axis *= output_shapes_[i];
+    }
+    size_t dim_of_indices = output_shapes_[IntToSize(axis_)];
+    size_t dim_after_indices = 1;
+    for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
+      dim_after_indices *= output_shapes_[i];
+    }
+
+    dims_[0] = dim_before_axis;
+    dims_[1] = dim_of_indices;
+    dims_[2] = dim_after_indices;
+    return;
+  }
+  size_t GetSize(const std::vector<size_t> &shape, const bool flag = true) const {
+    if (shape.size() == 0) {
+      return 0;
+    }
+    size_t result = flag ? sizeof(T) : sizeof(S);
+    for (size_t i = 0; i < shape.size(); i++) {
+      result *= shape[i];
+    }
+    return result;
+  }
+
+  std::vector<size_t> index_shapes_;
+  std::vector<size_t> grad_shapes_;
+  std::vector<size_t> output_shapes_;
+
+  size_t dims_[3] = {};
+  int axis_;
+  cudnnHandle_t handle_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_GATHER_GRAD_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cu
new file mode 100755
index 0000000000..fbc37d1cc3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cu
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <iostream>
+#include "backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T, typename S>
+__global__ void GatherGradKernel(const T *index, const S *grad, S *output, const size_t output_dim0,
+                                 const size_t output_dim1, const size_t output_dim2) {
+  size_t num = output_dim0 * output_dim1 * output_dim2;
+  size_t i, k;
+  for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num;
+       id += blockDim.x * gridDim.x) {
+    i = id / (output_dim1 * output_dim2) % output_dim0;
+    k = id % output_dim2;
+
+    size_t j_read = static_cast<size_t>(index[id]);
+    size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k;
+    output[read_id] = grad[id];
+  }
+  return;
+}
+template <typename T, typename S>
+void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0,
+                const size_t output_dim1, const size_t output_dim2, cudaStream_t stream) {
+  size_t size = output_dim0 * output_dim1 * output_dim2;
+  GatherGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(index, grad, output,
+                                                                 output_dim0, output_dim1, output_dim2);
+  return;
+}
+
+template void GatherGrad<int, float>(const int *index, const float *grad, float *output,
+                                     const size_t output_dim0, const size_t output_dim1,
+                                     const size_t output_dim2, cudaStream_t stream);
+
+template void GatherGrad<int, half>(const int *index, const half *grad, half *output,
+                                    const size_t output_dim0, const size_t output_dim1,
+                                    const size_t output_dim2, cudaStream_t stream);
+
+template void GatherGrad<int64_t, float>(const int64_t *index, const float *grad, float *output,
+                                         const size_t output_dim0, const size_t output_dim1,
+                                         const size_t output_dim2, cudaStream_t stream);
+
+template void GatherGrad<int64_t, half>(const int64_t *index, const half *grad, half *output,
+                                        const size_t output_dim0, const size_t output_dim1,
+                                        const size_t output_dim2, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh
new file mode 100644
index 0000000000..3872add999
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_GATHER_GRAD_GPU_CU_H
+#define MINDSPORE_GATHER_GRAD_GPU_CU_H
+template <typename T, typename S>
+void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0,
+                const size_t output_dim1, const size_t output_dim2, cudaStream_t stream);
+
+#endif
diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py
index c704969eea..d1593594f6 100644
--- a/mindspore/ops/_grad/grad_array_ops.py
+++ b/mindspore/ops/_grad/grad_array_ops.py
@@ -374,6 +374,17 @@ def get_bprop_gather_v2(self):
     return bprop
 
 
+@bprop_getters.register(P.GatherD)
+def get_bprop_gather_d(self):
+    """Generate bprop for GatherD"""
+
+    def bprop(x, dim, index, out, dout):
+        dx = G.GatherDGrad(dim)(index, dout)
+        return dx, zeros_like(dim), zeros_like(index)
+
+    return bprop
+
+
 @bprop_getters.register(P.SparseGatherV2)
 def get_bprop_sparse_gather_v2(self):
     """Generate bprop for SparseGatherV2"""
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index cc61c9c4c7..f20ed6cae6 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -1218,6 +1218,23 @@ class EluGrad(PrimitiveWithInfer):
         return x_dtype
 
 
+class GatherDGrad(PrimitiveWithInfer):
+    """Performs grad of GatherD operation."""
+
+    @prim_attr_register
+    def __init__(self, dim=0):
+        """Initialize GatherDGrad"""
+        validator.check_is_int(dim, "dim", self.name)
+        self.add_prim_attr("dim", dim)
+        self.init_prim_io_names(inputs=['index', 'grad'], outputs=['output'])
+
+    def infer_shape(self, index_shape, grad_shape):
+        return grad_shape
+
+    def infer_dtype(self, index_dtype, grad_dtype):
+        return grad_dtype
+
+
 class ResizeBilinearGrad(PrimitiveWithInfer):
     """Performs grad of ResizeBilinear operation."""
diff --git a/tests/st/ops/gpu/test_gather_grad_op.py b/tests/st/ops/gpu/test_gather_grad_op.py
new file mode 100644
index 0000000000..5d8700da18
--- /dev/null
+++ b/tests/st/ops/gpu/test_gather_grad_op.py
@@ -0,0 +1,163 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore as ms
+import mindspore.ops.operations._grad_ops as G
+from mindspore import Tensor
+
+class GatherDGradNet(nn.Cell):
+    def __init__(self, dim=0):
+        super(GatherDGradNet, self).__init__()
+        self.gather_d_grad = G.GatherDGrad(dim)
+
+    def construct(self, index, grad):
+        return self.gather_d_grad(index, grad)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_grad_graph_int32_fp32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    dim = 0
+    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
+    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
+                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
+    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
+                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
+    net = GatherDGradNet(dim)
+    output = net(index, grad)
+    error = 1e-4
+    diff = output.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_grad_graph_int64_fp32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    dim = 0
+    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
+    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
+                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
+    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
+                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
+    net = GatherDGradNet(dim)
+    output = net(index, grad)
+    error = 1e-4
+    diff = output.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_grad_graph_int32_fp16():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    dim = 0
+    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
+    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
+                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
+    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
+                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
+    net = GatherDGradNet(dim)
+    output = net(index, grad)
+    error = 1e-4
+    diff = output.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_grad_graph_int64_fp16():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    dim = 0
+    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
+    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
+                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
+    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
+                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
+    net = GatherDGradNet(dim)
+    output = net(index, grad)
+    error = 1e-4
+    diff = output.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_grad_pynative_int32_fp32():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    dim = 0
+    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
+    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
+                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
+    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
+                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
+    output = G.GatherDGrad(dim)(index, grad)
+    error = 1e-4
+    diff = output.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_grad_pynative_int64_fp32():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    dim = 0
+    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
+    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
+                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
+    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
+                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
+    output = G.GatherDGrad(dim)(index, grad)
+    error = 1e-4
+    diff = output.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_grad_pynative_int32_fp16():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    dim = 0
+    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
+    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
+                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
+    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
+                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
+    output = G.GatherDGrad(dim)(index, grad)
+    error = 1e-4
+    diff = output.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gather_grad_pynative_int64_fp16():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    dim = 0
+    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
+    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
+                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
+    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
+                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
+    output = G.GatherDGrad(dim)(index, grad)
+    error = 1e-4
+    diff = output.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
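
For reference, the `expect` arrays in these tests follow from the scatter that GatherDGrad performs: each gradient element is written back to the position it was gathered from, i.e. the `dim` coordinate of the write is taken from `index`. A minimal NumPy sketch of that semantics follows; the helper name `gather_d_grad_np` is illustrative and not part of the patch, and unlike the CUDA kernel above it zero-initializes the output and accumulates when `index` repeats a position along `dim`:

    import numpy as np

    def gather_d_grad_np(index, grad, dim, out_shape):
        # Scatter-add `grad` into a zeroed output at the positions named by
        # `index` along `dim` (the inverse of np.take_along_axis).
        out = np.zeros(out_shape, dtype=grad.dtype)
        for pos in np.ndindex(*index.shape):
            read = list(pos)
            read[dim] = index[pos]         # redirect the coordinate along `dim`
            out[tuple(read)] += grad[pos]  # accumulate on duplicate indices
        return out

    index = np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
    grad = np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                     [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]], np.float32)
    print(gather_d_grad_np(index, grad, 0, grad.shape))  # matches `expect` above

Because every column of `index` in these tests is a permutation of [0, 1], each output element is written exactly once, so the kernel's plain overwrite and this accumulating reference coincide; inputs that repeat an index along `dim` would expose the difference.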