|
|
|
@@ -27,7 +27,7 @@ namespace kernel { |
|
|
|
template <typename T, typename S> |
|
|
|
class GatherGradGpuKernel : public GpuKernel { |
|
|
|
public: |
|
|
|
GatherGradGpuKernel() : axis_(0), handle_(nullptr) {} |
|
|
|
GatherGradGpuKernel() : axis_(0) {} |
|
|
|
~GatherGradGpuKernel() = default; |
|
|
|
|
|
|
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } |
|
|
|
@@ -66,7 +66,7 @@ class GatherGradGpuKernel : public GpuKernel { |
|
|
|
} |
|
|
|
|
|
|
|
protected: |
|
|
|
void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } |
|
|
|
void InitResource() override {} |
|
|
|
void InitSizeLists() override { |
|
|
|
size_t size = GetSize(index_shapes_, true); |
|
|
|
input_size_list_.push_back(size); |
|
|
|
@@ -114,7 +114,6 @@ class GatherGradGpuKernel : public GpuKernel { |
|
|
|
|
|
|
|
size_t dims_[4] = {}; |
|
|
|
int axis_; |
|
|
|
cudnnHandle_t handle_; |
|
|
|
|
|
|
|
std::vector<size_t> input_size_list_; |
|
|
|
std::vector<size_t> output_size_list_; |
|
|
|
|