|
|
|
@@ -91,13 +91,14 @@ class ROIAlignGradGpuFwdKernel : public GpuKernel { |
|
|
|
roi_end_mode_ = 1; |
|
|
|
|
|
|
|
// Get channels, height & width |
|
|
|
batch_size_ = xdiff_shape_[0]; |
|
|
|
channels_ = xdiff_shape_[1]; |
|
|
|
height_ = xdiff_shape_[2]; |
|
|
|
width_ = xdiff_shape_[3]; |
|
|
|
|
|
|
|
// Get output_shape |
|
|
|
output_shape_ = {roi_rows_, channels_, height_, width_}; |
|
|
|
output_size_ = roi_rows_ * channels_ * height_ * width_ * sizeof(T); |
|
|
|
output_shape_ = {batch_size_, channels_, height_, width_}; |
|
|
|
output_size_ = batch_size_ * channels_ * height_ * width_ * sizeof(T); |
|
|
|
|
|
|
|
InitSizeLists(); |
|
|
|
return true; |
|
|
|
@@ -120,6 +121,7 @@ class ROIAlignGradGpuFwdKernel : public GpuKernel { |
|
|
|
|
|
|
|
int roi_rows_; |
|
|
|
int roi_cols_; |
|
|
|
int batch_size_; |
|
|
|
int channels_; |
|
|
|
int height_; |
|
|
|
int width_; |
|
|
|
|