Browse Source

!6005 Optimize ROI Align kernel

Merge pull request !6005 from JonathanY/rcnn
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
cd2646c98e
1 changed files with 19 additions and 2 deletions
  1. +19
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu

+ 19
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu View File

@@ -116,7 +116,15 @@ __global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes,
const int height, const int width, const int pooled_height, const int pooled_width) { const int height, const int width, const int pooled_height, const int pooled_width) {
for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size;
thread_idx += blockDim.x * gridDim.x) { thread_idx += blockDim.x * gridDim.x) {
int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
int n = thread_idx / pooled_width / pooled_height / channels;
const T *roi_box = roi_boxes + n * roi_cols;
if (roi_box[0] < static_cast<T>(0.001) && roi_box[1] < static_cast<T>(0.001) &&
roi_box[2] < static_cast<T>(0.001) && roi_box[3] < static_cast<T>(0.001) &&
roi_box[0] > static_cast<T>(-0.001) && roi_box[1] > static_cast<T>(-0.001) &&
roi_box[2] > static_cast<T>(-0.001) && roi_box[3] > static_cast<T>(-0.001)) {
continue;
}
int offset, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
T bin_size_h, bin_size_w, roi_start_h, roi_start_w; T bin_size_h, bin_size_w, roi_start_h, roi_start_w;


bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width, bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width,
@@ -183,7 +191,16 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
const int height, const int width, const int pooled_height, const int pooled_width) { const int height, const int width, const int pooled_height, const int pooled_width) {
for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size;
thread_idx += blockDim.x * gridDim.x) { thread_idx += blockDim.x * gridDim.x) {
int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
int n = thread_idx / pooled_width / pooled_height / channels;
const T *roi_box = roi_boxes + n * roi_cols;
if (roi_box[0] < static_cast<T>(0.001) && roi_box[1] < static_cast<T>(0.001) &&
roi_box[2] < static_cast<T>(0.001) && roi_box[3] < static_cast<T>(0.001) &&
roi_box[0] > static_cast<T>(-0.001) && roi_box[1] > static_cast<T>(-0.001) &&
roi_box[2] > static_cast<T>(-0.001) && roi_box[3] > static_cast<T>(-0.001)) {
continue;
}

int offset, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w;
T bin_size_h, bin_size_w, roi_start_h, roi_start_w; T bin_size_h, bin_size_w, roi_start_h, roi_start_w;


bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width, bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width,


Loading…
Cancel
Save