@@ -0,0 +1,38 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| #include "backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
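// Register GatherDGrad for each supported (index, value) dtype pair: the first
// template argument is the index type T, the second the grad/output type S.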
MS_REG_GPU_KERNEL_TWO(
  GatherDGrad,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  GatherGradGpuKernel, int, float)
MS_REG_GPU_KERNEL_TWO(
  GatherDGrad,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  GatherGradGpuKernel, int64_t, float)
MS_REG_GPU_KERNEL_TWO(
  GatherDGrad,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  GatherGradGpuKernel, int, half)
MS_REG_GPU_KERNEL_TWO(
  GatherDGrad,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  GatherGradGpuKernel, int64_t, half)
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,124 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_GATHER_GRAD_GPU_KERNEL_H
#define MINDSPORE_GATHER_GRAD_GPU_KERNEL_H

#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh"

namespace mindspore {
namespace kernel {
| template <typename T, typename S> | |||
| class GatherGradGpuKernel : public GpuKernel { | |||
| public: | |||
| GatherGradGpuKernel() : axis_(0), handle_(nullptr) {} | |||
| ~GatherGradGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
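  // Scatters the incoming gradient into the output buffer: every element of `grad`
  // is written to the position selected by the matching entry of `index` along `dim`.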
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    VARIABLE_NOT_USED(workspace);
    T *index_addr = GetDeviceAddress<T>(inputs, 0);
    S *grad_addr = GetDeviceAddress<S>(inputs, 1);
    S *output_addr = GetDeviceAddress<S>(outputs, 0);

    GatherGrad(index_addr, grad_addr, output_addr, dims_[0], dims_[1], dims_[2],
               reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    InitResource();
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 2) {
      MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGradGpuKernel needs 2.";
    }
    index_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    grad_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);

    axis_ = GetAttr<int>(kernel_node, "dim");
    if (axis_ < 0) {
      axis_ = axis_ + SizeToInt(index_shapes_.size());
    }

    Reshape();
    InitSizeLists();
    return true;
  }
 protected:
  void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
  void InitSizeLists() override {
    size_t size = GetSize(index_shapes_, true);
    input_size_list_.push_back(size);

    size = GetSize(grad_shapes_, false);
    input_size_list_.push_back(size);

    size = GetSize(output_shapes_, false);
    output_size_list_.push_back(size);
  }

 private:
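  // Collapses the output shape around `axis_` into three flattened dimensions
  // (everything before the axis, the axis itself, everything after it), so the
  // CUDA kernel can treat any rank as a 3-D problem:
  // dims_ = {prod(s[0:a]), s[a], prod(s[a+1:])} for shape s and axis a.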
  void Reshape() {
    size_t dim_before_axis = 1;
    for (size_t i = 0; i < IntToSize(axis_); i++) {
      dim_before_axis *= output_shapes_[i];
    }
    size_t dim_of_indices = output_shapes_[IntToSize(axis_)];
    size_t dim_after_indices = 1;
    for (size_t i = IntToSize(axis_) + 1; i < output_shapes_.size(); i++) {
      dim_after_indices *= output_shapes_[i];
    }
    dims_[0] = dim_before_axis;
    dims_[1] = dim_of_indices;
    dims_[2] = dim_after_indices;
  }
  // Returns the byte size of a tensor with the given shape; `is_index` selects
  // between the index element type T and the value element type S.
  size_t GetSize(const std::vector<size_t> &shape, const bool is_index = true) const {
    if (shape.empty()) {
      return 0;
    }
    size_t result = is_index ? sizeof(T) : sizeof(S);
    for (size_t i = 0; i < shape.size(); i++) {
      result *= shape[i];
    }
    return result;
  }
| std::vector<size_t> index_shapes_; | |||
| std::vector<size_t> grad_shapes_; | |||
| std::vector<size_t> output_shapes_; | |||
| size_t dims_[3] = {}; | |||
| int axis_; | |||
| cudnnHandle_t handle_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_GATHER_GRAD_GPU_KERNEL_H | |||
@@ -0,0 +1,59 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| #include <iostream> | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/gather_grad.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T, typename S> | |||
| __global__ void GatherGradKernel(const T *index, const S *grad, S *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2) { | |||
| size_t num = output_dim0 * output_dim1 * output_dim2; | |||
| size_t i, k; | |||
| for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < num; | |||
| id += blockDim.x * gridDim.x) { | |||
| i = id / (output_dim1 * output_dim2) % output_dim0; | |||
| k = id % output_dim2; | |||
| size_t j_read = static_cast<size_t>(index[id]); | |||
| size_t read_id = i * output_dim1 * output_dim2 + j_read * output_dim2 + k; | |||
| output[read_id] = grad[id]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T, typename S> | |||
| void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2, cudaStream_t stream) { | |||
| size_t size = output_dim0 * output_dim1 * output_dim2; | |||
| GatherGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(index, grad, output, | |||
| output_dim0, output_dim1, output_dim2); | |||
| return; | |||
| } | |||
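// Explicit instantiations for the (index, value) type pairs registered in
// gather_grad_gpu_kernel.cc.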
| template void GatherGrad<int, float>(const int *index, const float *grad, float *output, | |||
| const size_t output_dim0, const size_t output_dim1, | |||
| const size_t output_dim2, cudaStream_t stream); | |||
| template void GatherGrad<int, half>(const int *index, const half *grad, half *output, | |||
| const size_t output_dim0, const size_t output_dim1, | |||
| const size_t output_dim2, cudaStream_t stream); | |||
| template void GatherGrad<int64_t, float>(const int64_t *index, const float *grad, float *output, | |||
| const size_t output_dim0, const size_t output_dim1, | |||
| const size_t output_dim2, cudaStream_t stream); | |||
| template void GatherGrad<int64_t, half>(const int64_t *index, const half *grad, half *output, | |||
| const size_t output_dim0, const size_t output_dim1, | |||
| const size_t output_dim2, cudaStream_t stream); | |||
@@ -0,0 +1,23 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_GATHER_GRAD_GPU_CU_H
#define MINDSPORE_GATHER_GRAD_GPU_CU_H
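// Host-side launcher implemented in gather_grad.cu; T is the index type, S the
// grad/output type, and the three dims describe the collapsed 3-D shape.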
| template <typename T, typename S> | |||
| void GatherGrad(const T *index, const S *grad, S *output, const size_t output_dim0, | |||
| const size_t output_dim1, const size_t output_dim2, cudaStream_t stream); | |||
| #endif | |||
@@ -374,6 +374,15 @@ def get_bprop_gather_v2(self):
    return bprop

@bprop_getters.register(P.GatherD)
def get_bprop_gather_d(self):
    """Generate bprop for GatherD"""

    def bprop(x, dim, index, out, dout):
        dx = G.GatherDGrad(dim)(index, dout)
        # bprop must return one gradient per forward input (x, dim, index);
        # dim and index take zero gradients.
        return dx, zeros_like(dim), zeros_like(index)

    return bprop

@bprop_getters.register(P.SparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Generate bprop for SparseGatherV2"""
@@ -1218,6 +1218,23 @@ class EluGrad(PrimitiveWithInfer):
        return x_dtype

class GatherDGrad(PrimitiveWithInfer):
    """Performs grad of GatherD operation."""

    @prim_attr_register
    def __init__(self, dim=0):
        """Initialize GatherDGrad"""
        validator.check_value_type("dim", dim, [int], self.name)
        self.add_prim_attr("dim", dim)
        self.init_prim_io_names(inputs=['index', 'grad'], outputs=['output'])
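    # The GPU kernel assumes x, index and grad share one shape, so the gradient
    # of x can reuse the shape and dtype of the incoming grad.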
    def infer_shape(self, index_shape, grad_shape):
        return grad_shape

    def infer_dtype(self, index_dtype, grad_dtype):
        return grad_dtype

class ResizeBilinearGrad(PrimitiveWithInfer):
    """Performs grad of ResizeBilinear operation."""
@@ -0,0 +1,163 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops.operations._grad_ops as G
from mindspore import Tensor

class GatherDGradNet(nn.Cell):
    def __init__(self, dim=0):
        super(GatherDGradNet, self).__init__()
        self.gather_d_grad = G.GatherDGrad(dim)

    def construct(self, index, grad):
        return self.gather_d_grad(index, grad)
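
# In every test below each column of `index` holds a permutation of the row ids,
# so GatherDGrad(dim=0) routes each entry of `grad` back to the row named by
# `index`; `expect` is that permutation applied to `grad`.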

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_graph_int32_fp32():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
    net = GatherDGradNet(dim)
    output = net(index, grad)
    error = 1e-4
    diff = np.abs(output.asnumpy() - expect)
    assert np.all(diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_graph_int64_fp32():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
    net = GatherDGradNet(dim)
    output = net(index, grad)
    error = 1e-4
    diff = np.abs(output.asnumpy() - expect)
    assert np.all(diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_graph_int32_fp16():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
    net = GatherDGradNet(dim)
    output = net(index, grad)
    error = 1e-4
    diff = np.abs(output.asnumpy() - expect)
    assert np.all(diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_graph_int64_fp16():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
    net = GatherDGradNet(dim)
    output = net(index, grad)
    error = 1e-4
    diff = np.abs(output.asnumpy() - expect)
    assert np.all(diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_pynative_int32_fp32():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
    output = G.GatherDGrad(dim)(index, grad)
    error = 1e-4
    diff = np.abs(output.asnumpy() - expect)
    assert np.all(diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_pynative_int64_fp32():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float32)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float32)
    output = G.GatherDGrad(dim)(index, grad)
    error = 1e-4
    diff = np.abs(output.asnumpy() - expect)
    assert np.all(diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_pynative_int32_fp16():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int32)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
    output = G.GatherDGrad(dim)(index, grad)
    error = 1e-4
    diff = np.abs(output.asnumpy() - expect)
    assert np.all(diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gather_grad_pynative_int64_fp16():
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    dim = 0
    index = Tensor(np.array([[0, 1, 1, 0, 0], [1, 0, 0, 1, 1]]), ms.int64)
    grad = Tensor(np.array([[0.9031, 0.0890, 0.2779, 0.3198, 0.5710],
                            [0.6949, 0.8439, 0.2003, 0.6868, 0.4437]]), ms.float16)
    expect = np.array([[0.9031, 0.8439, 0.2003, 0.3198, 0.5710],
                       [0.6949, 0.0890, 0.2779, 0.6868, 0.4437]], np.float16)
    output = G.GatherDGrad(dim)(index, grad)
    error = 1e-4
    diff = np.abs(output.asnumpy() - expect)
    assert np.all(diff < error)