Browse Source

roi end mode

tags/v1.1.0
jonwe 5 years ago
parent
commit
575280bb61
5 changed files with 82 additions and 217 deletions
  1. +10
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu
  2. +0
    -78
      tests/st/ops/gpu/test_roi_align_grad_half_op.py
  3. +28
    -31
      tests/st/ops/gpu/test_roi_align_grad_op.py
  4. +0
    -49
      tests/st/ops/gpu/test_roi_align_half_op.py
  5. +44
    -54
      tests/st/ops/gpu/test_roi_align_op.py

+ 10
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu View File

@@ -91,16 +91,21 @@ __device__ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const
}

// Scale and shift ROI
T roi_offset = roi_end_mode == 0 ? static_cast<T>(0.5) : static_cast<T>(.0);
*roi_start_w = roi_box[0] * spatial_scale - roi_offset;
*roi_start_h = roi_box[1] * spatial_scale - roi_offset;
T roi_end_w = roi_box[2] * spatial_scale - roi_offset;
T roi_end_h = roi_box[3] * spatial_scale - roi_offset;
*roi_start_w = roi_box[0] * spatial_scale;
*roi_start_h = roi_box[1] * spatial_scale;
T roi_end_w = (roi_box[2] + static_cast<T>(roi_end_mode)) * spatial_scale;
T roi_end_h = (roi_box[3] + static_cast<T>(roi_end_mode)) * spatial_scale;

// New ROI height/width
T roi_width = roi_end_w - (*roi_start_w);
T roi_height = roi_end_h - (*roi_start_h);

if (roi_end_mode == 0) { // backward compatibility
// Force malformed ROIs to be 1x1
roi_width = roi_width > static_cast<T>(1.0) ? roi_width : static_cast<T>(1.0);
roi_height = roi_height > static_cast<T>(1.0) ? roi_height : static_cast<T>(1.0);
}

// ratio of roi / pooled
*bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
*bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);


+ 0
- 78
tests/st/ops/gpu/test_roi_align_grad_half_op.py View File

@@ -1,78 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class NetROIAlignGrad(nn.Cell):
def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num):
super(NetROIAlignGrad, self).__init__()
self.roiAlignGrad = G.ROIAlignGrad(
xdiff_shape,
pooled_height,
pooled_width,
spatial_scale,
sample_num)

def construct(self, dy, rois):
return self.roiAlignGrad(dy, rois)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align_grad_half():
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float16))

dy = Tensor(np.array([[[
[.1, .2, .3],
[.1, .2, .3],
[.1, .2, .3]
]]], np.float16))

xdiff_shape = (1, 1, 6, 6)
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
roi_align_grad = NetROIAlignGrad(
xdiff_shape,
pooled_height,
pooled_width,
spatial_scale,
sample_num)
output = roi_align_grad(dy, rois)
print(output)
# the out if aligned is True
# expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]])
expect = ([[[[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075]]]])
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)

+ 28
- 31
tests/st/ops/gpu/test_roi_align_grad_op.py View File

@@ -42,37 +42,34 @@ class NetROIAlignGrad(nn.Cell):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align_grad():
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32))
def roi_align_grad_case(data_type):
rois = Tensor(np.array([[0, -2.0, -2.0, 21.0, 21.0]], data_type))

dy = Tensor(np.array([[[
[.1, .2, .3],
[.1, .2, .3],
[.1, .2, .3]
]]], np.float32))
dy = Tensor(np.array([[[
[.1, .2, .3],
[.1, .2, .3],
[.1, .2, .3]
]]], data_type))

xdiff_shape = (1, 1, 6, 6)
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
xdiff_shape = (1, 1, 6, 6)
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
roi_align_grad = NetROIAlignGrad(
xdiff_shape,
pooled_height,
pooled_width,
spatial_scale,
sample_num)
output = roi_align_grad(dy, rois)
print(output)
# the out if aligned is True
# expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375],
# [0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]])
expect = ([[[[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075]]]])
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)
roi_align_grad = NetROIAlignGrad(
xdiff_shape,
pooled_height,
pooled_width,
spatial_scale,
sample_num)
output = roi_align_grad(dy, rois)
print(output)
expect = ([[[[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075],
[0.025, 0.025, 0.05, 0.05, 0.075, 0.075]]]])
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)

roi_align_grad_case(np.float32)
roi_align_grad_case(np.float16)

+ 0
- 49
tests/st/ops/gpu/test_roi_align_half_op.py View File

@@ -1,49 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align_half():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.array([[
[[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24],
[25, 26, 27, 28, 29, 30],
[31, 32, 33, 34, 35, 36]]
]], np.float16))

rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float16))

# test case 1
pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[1.2333, 2.1000, 3.3000, 4.5000],
[6.4333, 7.3000, 8.5000, 9.7000],
[13.6333, 14.5000, 15.7000, 16.9000],
[20.8333, 21.7000, 22.9000, 24.1000]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=1)

+ 44
- 54
tests/st/ops/gpu/test_roi_align_op.py View File

@@ -25,61 +25,51 @@ from mindspore.ops import operations as P
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_roi_align():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.array([[
[[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24],
[25, 26, 27, 28, 29, 30],
[31, 32, 33, 34, 35, 36]]
]], np.float32))
def roi_align_case(data_type):
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = Tensor(np.array([[
[[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24],
[25, 26, 27, 28, 29, 30],
[31, 32, 33, 34, 35, 36]]
]], data_type))

rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32))
# test case 1
rois = Tensor(np.array([[0, -2.0, -2.0, 21.0, 21.0]], data_type))
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
roi_align = P.ROIAlign(pooled_height, pooled_width,
spatial_scale, sample_num, 1)
output = roi_align(x, rois)
print(output)
expect = [[[[4.5, 6.5, 8.5],
[16.5, 18.5, 20.5],
[28.5, 30.5, 32.5]]]]
assert (output.asnumpy() == expect).all()

# test case 1
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[2.75, 4.5, 6.5],
[13.25, 15., 17.],
[25.25, 27., 29.]]]]
assert (output.asnumpy() == expect).all()
# test case 2
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], data_type))
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2
roi_align = P.ROIAlign(pooled_height, pooled_width,
spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[4.5, 6.5, 8.5],
[16.5, 18.5, 20.5],
[28.5, 30.5, 32.5]]]]
assert (output.asnumpy() == expect).all()

# test case 2
pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[1.2333, 2.1000, 3.3000, 4.5000],
[6.4333, 7.3000, 8.5000, 9.7000],
[13.6333, 14.5000, 15.7000, 16.9000],
[20.8333, 21.7000, 22.9000, 24.1000]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)
# test case 3
pooled_height, pooled_width, spatial_scale, sample_num = 2, 2, 1.0, -1
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], data_type))
roi_align = P.ROIAlign(pooled_height, pooled_width,
spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[6.295, 0.],
[0., 0.]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=2)

# test case 3
pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.3, 3
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0],
[0, 1.0, 0.0, 19.0, 18.0]],
np.float32))
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[3.3333, 5.5000, 7.6667],
[16.3333, 18.5000, 20.6667],
[29.3333, 31.5000, 33.6667]]],
[[[4.5000, 6.3000, 8.1000],
[14.9000, 16.7000, 18.5000],
[25.7000, 27.5000, 29.3000]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)

# test case 4
pooled_height, pooled_width, spatial_scale, sample_num = 2, 2, 1.0, -1
rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32))
roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0)
output = roi_align(x, rois)
print(output)
expect = [[[[8.2222, 0.],
[0., 0.]]]]
np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4)
roi_align_case(np.float32)
roi_align_case(np.float16)

Loading…
Cancel
Save