diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu index 477c227e34..6e876d2f64 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu @@ -116,7 +116,15 @@ __global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes, const int height, const int width, const int pooled_height, const int pooled_width) { for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; thread_idx += blockDim.x * gridDim.x) { - int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w; + int n = thread_idx / pooled_width / pooled_height / channels; + const T *roi_box = roi_boxes + n * roi_cols; + if (roi_box[0] < static_cast(0.001) && roi_box[1] < static_cast(0.001) && + roi_box[2] < static_cast(0.001) && roi_box[3] < static_cast(0.001) && + roi_box[0] > static_cast(-0.001) && roi_box[1] > static_cast(-0.001) && + roi_box[2] > static_cast(-0.001) && roi_box[3] > static_cast(-0.001)) { + continue; + } + int offset, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w; T bin_size_h, bin_size_w, roi_start_h, roi_start_w; bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width, @@ -183,7 +191,16 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes, const int height, const int width, const int pooled_height, const int pooled_width) { for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; thread_idx += blockDim.x * gridDim.x) { - int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w; + int n = thread_idx / pooled_width / pooled_height / channels; + const T *roi_box = roi_boxes + n * roi_cols; + if (roi_box[0] < static_cast(0.001) && roi_box[1] < static_cast(0.001) && + roi_box[2] < static_cast(0.001) && roi_box[3] < static_cast(0.001) && + roi_box[0] > static_cast(-0.001) && roi_box[1] > static_cast(-0.001) && + roi_box[2] > static_cast(-0.001) && roi_box[3] > static_cast(-0.001)) { + continue; + } + + int offset, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w; T bin_size_h, bin_size_w, roi_start_h, roi_start_w; bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width,