diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h index 43baea453f..3f2c258d8e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h @@ -27,7 +27,7 @@ namespace kernel { template class GatherGpuFwdKernel : public GpuKernel { public: - GatherGpuFwdKernel() : axis_(0), handle_(nullptr) {} + GatherGpuFwdKernel() : axis_(0) {} ~GatherGpuFwdKernel() = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -65,7 +65,7 @@ class GatherGpuFwdKernel : public GpuKernel { } protected: - void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + void InitResource() override {} void InitSizeLists() override { size_t size = GetSize(input_shapes_, true); input_size_list_.push_back(size); @@ -113,7 +113,6 @@ class GatherGpuFwdKernel : public GpuKernel { size_t dims_[4] = {}; int axis_; - cudnnHandle_t handle_; std::vector input_size_list_; std::vector output_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h index 881551faf8..e47d4140c9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h @@ -27,7 +27,7 @@ namespace kernel { template class GatherGradGpuKernel : public GpuKernel { public: - GatherGradGpuKernel() : axis_(0), handle_(nullptr) {} + GatherGradGpuKernel() : axis_(0) {} ~GatherGradGpuKernel() = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -66,7 +66,7 @@ class GatherGradGpuKernel : public GpuKernel { } protected: - void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + void InitResource() override {} void InitSizeLists() override { size_t size = GetSize(index_shapes_, true); input_size_list_.push_back(size); @@ -114,7 +114,6 @@ class GatherGradGpuKernel : public GpuKernel { size_t dims_[4] = {}; int axis_; - cudnnHandle_t handle_; std::vector input_size_list_; std::vector output_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu index 31cb096b21..5feb28deb2 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu @@ -28,8 +28,12 @@ __global__ void GatherKernel(const T *input, const S *index, T *output, const si i = id / (dim_at_axis_output * dim_after_axis); k = id % dim_after_axis; - CUDA_KERNEL_ASSERT(index[id] >= 0); - size_t j_read = static_cast(index[id]); + S j = index[id]; + if (j < 0) { + j += static_cast(dim_at_axis_input); + } + CUDA_KERNEL_ASSERT(j >= 0); + size_t j_read = static_cast(j); CUDA_KERNEL_ASSERT(j_read < dim_at_axis_input); size_t read_id = i * dim_at_axis_input * dim_after_axis + j_read * dim_after_axis + k; output[id] = input[read_id]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cu index bd67da4690..7e09af2ad4 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather_grad.cu @@ -30,8 +30,12 @@ __global__ void GatherGradKernel(const size_t num, const T *index, const S *grad i = id / (dim_at_axis_index * dim_after_axis); k = id % dim_after_axis; - CUDA_KERNEL_ASSERT(index[id] >= 0); - size_t j_read = static_cast(index[id]); + T j = index[id]; + if (j < 0) { + j += static_cast(dim_at_axis_output); + } + CUDA_KERNEL_ASSERT(j >= 0); + size_t j_read = static_cast(j); CUDA_KERNEL_ASSERT(j_read < dim_at_axis_output); size_t read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k; MsAtomicAdd(output + read_id, grad[id]);