Browse Source

roi align memory leak

tags/v1.0.0
Jonathan Yan 5 years ago
parent
commit
bbd19dbe43
3 changed files with 34 additions and 30 deletions
  1. +29
    -26
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu
  2. +4
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h
  3. +1
    -1
      tests/st/ops/gpu/test_roi_align_op.py

+ 29
- 26
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu View File

@@ -18,15 +18,14 @@
#include "util.cuh"
#include "runtime/device/gpu/cuda_common.h"

inline __device__ int roi_cast_int(float x) { return static_cast<int>(x); }
inline __device__ int roi_cast_int(float x) { return __float2int_rd(x); }
inline __device__ int roi_cast_int(half x) { return __half2int_rd(x); }

template <typename T>
__device__ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high,
int *y_high, T *w1, T *w2, T *w3, T *w4) {
// return 0 if out of map boundary
if (y <= static_cast<T>(-1.0) || y >= static_cast<T>(height) || x <= static_cast<T>(-1.0) ||
x >= static_cast<T>(width)) {
if (y < static_cast<T>(-1.0) || y > static_cast<T>(height) || x < static_cast<T>(-1.0) || x > static_cast<T>(width)) {
*w1 = *w2 = *w3 = *w4 = 0;
*x_low = *x_high = *y_low = *y_high = -1;
return;
@@ -137,16 +136,18 @@ __global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes,
static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
// bilinear interpolate by shifted y / x
// calculate bilinear interpolation
int x_low, y_low, x_high, y_high;
int x_low = 0, y_low = 0, x_high = 0, y_high = 0;
T w1, w2, w3, w4;
bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4);
T v1 = input[y_low * width + x_low + offset];
T v2 = input[y_low * width + x_high + offset];
T v3 = input[y_high * width + x_low + offset];
T v4 = input[y_high * width + x_high + offset];

T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
accumulate_val += val;
if (x_low != -1 || x_high != -1 || y_low != -1 || y_high != -1) {
T v1 = input[offset + y_low * width + x_low];
T v2 = input[offset + y_low * width + x_high];
T v3 = input[offset + y_high * width + x_low];
T v4 = input[offset + y_high * width + x_high];

T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
accumulate_val += val;
}
}
}
accumulate_val /= count_points_in_grid_cell;
@@ -205,23 +206,25 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
// bilinear interpolate by shifted y / x
// calculate bilinear interpolation
int x_low, y_low, x_high, y_high;
int x_low = 0, y_low = 0, x_high = 0, y_high = 0;
T w1, w2, w3, w4;
bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4);
T g1 = top_diff_this_bin * w1 / count_points_in_grid_cell;
T g2 = top_diff_this_bin * w2 / count_points_in_grid_cell;
T g3 = top_diff_this_bin * w3 / count_points_in_grid_cell;
T g4 = top_diff_this_bin * w4 / count_points_in_grid_cell;

T *dx_1 = dx + offset + y_low * width + x_low;
T *dx_2 = dx + offset + y_low * width + x_high;
T *dx_3 = dx + offset + y_high * width + x_low;
T *dx_4 = dx + offset + y_high * width + x_high;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
MsAtomicAdd(dx_1, g1);
MsAtomicAdd(dx_2, g2);
MsAtomicAdd(dx_3, g3);
MsAtomicAdd(dx_4, g4);
if (x_low != -1 || x_high != -1 || y_low != -1 || y_high != -1) {
T g1 = top_diff_this_bin * w1 / count_points_in_grid_cell;
T g2 = top_diff_this_bin * w2 / count_points_in_grid_cell;
T g3 = top_diff_this_bin * w3 / count_points_in_grid_cell;
T g4 = top_diff_this_bin * w4 / count_points_in_grid_cell;

T *dx_1 = dx + offset + y_low * width + x_low;
T *dx_2 = dx + offset + y_low * width + x_high;
T *dx_3 = dx + offset + y_high * width + x_low;
T *dx_4 = dx + offset + y_high * width + x_high;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
MsAtomicAdd(dx_1, g1);
MsAtomicAdd(dx_2, g2);
MsAtomicAdd(dx_3, g3);
MsAtomicAdd(dx_4, g4);
}
}
}
}


+ 4
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h View File

@@ -71,12 +71,12 @@ class ROIAlignGpuFwdKernel : public GpuKernel {
}

// Get channels, height & width
int batch_N = x_shape[0];
batch_N_ = x_shape[0];
channels_ = x_shape[1];
height_ = x_shape[2];
width_ = x_shape[3];
x_shape_ = {batch_N, channels_, height_, width_};
x_size_ = batch_N * channels_ * height_ * width_ * sizeof(T);
x_shape_ = {batch_N_, channels_, height_, width_};
x_size_ = batch_N_ * channels_ * height_ * width_ * sizeof(T);

// Get rois rows and cols
roi_rows_ = rois_shape[0];
@@ -119,6 +119,7 @@ class ROIAlignGpuFwdKernel : public GpuKernel {

int roi_rows_;
int roi_cols_;
int batch_N_;
int channels_;
int height_;
int width_;


+ 1
- 1
tests/st/ops/gpu/test_roi_align_op.py View File

@@ -80,6 +80,6 @@ def test_roi_align():
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num)
output = roi_align(x, rois)
print(output)
expect = [[[[4.625, 0.],
expect = [[[[8.2222, 0.],
[0., 0.]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)

Loading…
Cancel
Save