| @@ -27,7 +27,7 @@ namespace kernel { | |||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| class GatherGpuFwdKernel : public GpuKernel { | class GatherGpuFwdKernel : public GpuKernel { | ||||
| public: | public: | ||||
| GatherGpuFwdKernel() : axis_(0), handle_(nullptr) {} | |||||
| GatherGpuFwdKernel() : axis_(0) {} | |||||
| ~GatherGpuFwdKernel() = default; | ~GatherGpuFwdKernel() = default; | ||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| @@ -65,7 +65,7 @@ class GatherGpuFwdKernel : public GpuKernel { | |||||
| } | } | ||||
| protected: | protected: | ||||
| void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } | |||||
| void InitResource() override {} | |||||
| void InitSizeLists() override { | void InitSizeLists() override { | ||||
| size_t size = GetSize(input_shapes_, true); | size_t size = GetSize(input_shapes_, true); | ||||
| input_size_list_.push_back(size); | input_size_list_.push_back(size); | ||||
| @@ -113,7 +113,6 @@ class GatherGpuFwdKernel : public GpuKernel { | |||||
| size_t dims_[4] = {}; | size_t dims_[4] = {}; | ||||
| int axis_; | int axis_; | ||||
| cudnnHandle_t handle_; | |||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||
| @@ -27,7 +27,7 @@ namespace kernel { | |||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| class GatherGradGpuKernel : public GpuKernel { | class GatherGradGpuKernel : public GpuKernel { | ||||
| public: | public: | ||||
| GatherGradGpuKernel() : axis_(0), handle_(nullptr) {} | |||||
| GatherGradGpuKernel() : axis_(0) {} | |||||
| ~GatherGradGpuKernel() = default; | ~GatherGradGpuKernel() = default; | ||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | ||||
| @@ -66,7 +66,7 @@ class GatherGradGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| protected: | protected: | ||||
| void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } | |||||
| void InitResource() override {} | |||||
| void InitSizeLists() override { | void InitSizeLists() override { | ||||
| size_t size = GetSize(index_shapes_, true); | size_t size = GetSize(index_shapes_, true); | ||||
| input_size_list_.push_back(size); | input_size_list_.push_back(size); | ||||
| @@ -114,7 +114,6 @@ class GatherGradGpuKernel : public GpuKernel { | |||||
| size_t dims_[4] = {}; | size_t dims_[4] = {}; | ||||
| int axis_; | int axis_; | ||||
| cudnnHandle_t handle_; | |||||
| std::vector<size_t> input_size_list_; | std::vector<size_t> input_size_list_; | ||||
| std::vector<size_t> output_size_list_; | std::vector<size_t> output_size_list_; | ||||
| @@ -28,8 +28,12 @@ __global__ void GatherKernel(const T *input, const S *index, T *output, const si | |||||
| i = id / (dim_at_axis_output * dim_after_axis); | i = id / (dim_at_axis_output * dim_after_axis); | ||||
| k = id % dim_after_axis; | k = id % dim_after_axis; | ||||
| CUDA_KERNEL_ASSERT(index[id] >= 0); | |||||
| size_t j_read = static_cast<size_t>(index[id]); | |||||
| S j = index[id]; | |||||
| if (j < 0) { | |||||
| j += static_cast<S>(dim_at_axis_input); | |||||
| } | |||||
| CUDA_KERNEL_ASSERT(j >= 0); | |||||
| size_t j_read = static_cast<size_t>(j); | |||||
| CUDA_KERNEL_ASSERT(j_read < dim_at_axis_input); | CUDA_KERNEL_ASSERT(j_read < dim_at_axis_input); | ||||
| size_t read_id = i * dim_at_axis_input * dim_after_axis + j_read * dim_after_axis + k; | size_t read_id = i * dim_at_axis_input * dim_after_axis + j_read * dim_after_axis + k; | ||||
| output[id] = input[read_id]; | output[id] = input[read_id]; | ||||
| @@ -30,8 +30,12 @@ __global__ void GatherGradKernel(const size_t num, const T *index, const S *grad | |||||
| i = id / (dim_at_axis_index * dim_after_axis); | i = id / (dim_at_axis_index * dim_after_axis); | ||||
| k = id % dim_after_axis; | k = id % dim_after_axis; | ||||
| CUDA_KERNEL_ASSERT(index[id] >= 0); | |||||
| size_t j_read = static_cast<size_t>(index[id]); | |||||
| T j = index[id]; | |||||
| if (j < 0) { | |||||
| j += static_cast<T>(dim_at_axis_output); | |||||
| } | |||||
| CUDA_KERNEL_ASSERT(j >= 0); | |||||
| size_t j_read = static_cast<size_t>(j); | |||||
| CUDA_KERNEL_ASSERT(j_read < dim_at_axis_output); | CUDA_KERNEL_ASSERT(j_read < dim_at_axis_output); | ||||
| size_t read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k; | size_t read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k; | ||||
| MsAtomicAdd(output + read_id, grad[id]); | MsAtomicAdd(output + read_id, grad[id]); | ||||