From 575280bb61fc0e57f3fe84b3afe71b6dcf11b59f Mon Sep 17 00:00:00 2001 From: jonwe Date: Wed, 4 Nov 2020 13:21:10 -0500 Subject: [PATCH] roi end mode --- .../gpu/cuda_impl/roi_align_impl.cu | 15 ++- .../st/ops/gpu/test_roi_align_grad_half_op.py | 78 --------------- tests/st/ops/gpu/test_roi_align_grad_op.py | 59 ++++++----- tests/st/ops/gpu/test_roi_align_half_op.py | 49 ---------- tests/st/ops/gpu/test_roi_align_op.py | 98 +++++++++---------- 5 files changed, 82 insertions(+), 217 deletions(-) delete mode 100644 tests/st/ops/gpu/test_roi_align_grad_half_op.py delete mode 100644 tests/st/ops/gpu/test_roi_align_half_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu index 96158b49c7..f46b0c1986 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu @@ -91,16 +91,21 @@ __device__ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const } // Scale and shift ROI - T roi_offset = roi_end_mode == 0 ? static_cast(0.5) : static_cast(.0); - *roi_start_w = roi_box[0] * spatial_scale - roi_offset; - *roi_start_h = roi_box[1] * spatial_scale - roi_offset; - T roi_end_w = roi_box[2] * spatial_scale - roi_offset; - T roi_end_h = roi_box[3] * spatial_scale - roi_offset; + *roi_start_w = roi_box[0] * spatial_scale; + *roi_start_h = roi_box[1] * spatial_scale; + T roi_end_w = (roi_box[2] + static_cast(roi_end_mode)) * spatial_scale; + T roi_end_h = (roi_box[3] + static_cast(roi_end_mode)) * spatial_scale; // New ROI height/width T roi_width = roi_end_w - (*roi_start_w); T roi_height = roi_end_h - (*roi_start_h); + if (roi_end_mode == 0) { // backward compatibility + // Force malformed ROIs to be 1x1 + roi_width = roi_width > static_cast(1.0) ? roi_width : static_cast(1.0); + roi_height = roi_height > static_cast(1.0) ? roi_height : static_cast(1.0); + } + // ratio of roi / pooled *bin_size_h = static_cast(roi_height) / static_cast(pooled_height); *bin_size_w = static_cast(roi_width) / static_cast(pooled_width); diff --git a/tests/st/ops/gpu/test_roi_align_grad_half_op.py b/tests/st/ops/gpu/test_roi_align_grad_half_op.py deleted file mode 100644 index 95e48f53a7..0000000000 --- a/tests/st/ops/gpu/test_roi_align_grad_half_op.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops.operations import _grad_ops as G - -context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - - -class NetROIAlignGrad(nn.Cell): - def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num): - super(NetROIAlignGrad, self).__init__() - self.roiAlignGrad = G.ROIAlignGrad( - xdiff_shape, - pooled_height, - pooled_width, - spatial_scale, - sample_num) - - def construct(self, dy, rois): - return self.roiAlignGrad(dy, rois) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_roi_align_grad_half(): - rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float16)) - - dy = Tensor(np.array([[[ - [.1, .2, .3], - [.1, .2, .3], - [.1, .2, .3] - ]]], np.float16)) - - xdiff_shape = (1, 1, 6, 6) - pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - roi_align_grad = NetROIAlignGrad( - xdiff_shape, - pooled_height, - pooled_width, - spatial_scale, - sample_num) - output = roi_align_grad(dy, rois) - print(output) - # the out if aligned is True - # expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563], - # [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], - # [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], - # [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], - # [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], - # [0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]]) - expect = ([[[[0.025, 0.025, 0.05, 0.05, 0.075, 0.075], - [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], - [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], - [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], - [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], - [0.025, 0.025, 0.05, 0.05, 0.075, 0.075]]]]) - np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) diff --git a/tests/st/ops/gpu/test_roi_align_grad_op.py b/tests/st/ops/gpu/test_roi_align_grad_op.py index 0b834a67ad..ddb4eb6591 100644 --- a/tests/st/ops/gpu/test_roi_align_grad_op.py +++ b/tests/st/ops/gpu/test_roi_align_grad_op.py @@ -42,37 +42,34 @@ class NetROIAlignGrad(nn.Cell): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_roi_align_grad(): - rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32)) + def roi_align_grad_case(data_type): + rois = Tensor(np.array([[0, -2.0, -2.0, 21.0, 21.0]], data_type)) - dy = Tensor(np.array([[[ - [.1, .2, .3], - [.1, .2, .3], - [.1, .2, .3] - ]]], np.float32)) + dy = Tensor(np.array([[[ + [.1, .2, .3], + [.1, .2, .3], + [.1, .2, .3] + ]]], data_type)) - xdiff_shape = (1, 1, 6, 6) - pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 + xdiff_shape = (1, 1, 6, 6) + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - roi_align_grad = NetROIAlignGrad( - xdiff_shape, - pooled_height, - pooled_width, - spatial_scale, - sample_num) - output = roi_align_grad(dy, rois) - print(output) - # the out if aligned is True - # expect = ([[[[0.0563, 0.0563, 0.0750, 0.0938, 0.1125, 0.0563], - # [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], - # [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], - # [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], - # [0.0375, 0.0375, 0.0500, 0.0625, 0.0750, 0.0375], - # [0.0188, 0.0188, 0.0250, 0.0312, 0.0375, 0.0188]]]]) - expect = ([[[[0.025, 0.025, 0.05, 0.05, 0.075, 0.075], - [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], - [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], - [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], - [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], - [0.025, 0.025, 0.05, 0.05, 0.075, 0.075]]]]) - np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) + roi_align_grad = NetROIAlignGrad( + xdiff_shape, + pooled_height, + pooled_width, + spatial_scale, + sample_num) + output = roi_align_grad(dy, rois) + print(output) + expect = ([[[[0.025, 0.025, 0.05, 0.05, 0.075, 0.075], + [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], + [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], + [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], + [0.025, 0.025, 0.05, 0.05, 0.075, 0.075], + [0.025, 0.025, 0.05, 0.05, 0.075, 0.075]]]]) + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) + + roi_align_grad_case(np.float32) + roi_align_grad_case(np.float16) diff --git a/tests/st/ops/gpu/test_roi_align_half_op.py b/tests/st/ops/gpu/test_roi_align_half_op.py deleted file mode 100644 index 2d5b89c38b..0000000000 --- a/tests/st/ops/gpu/test_roi_align_half_op.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import numpy as np -import pytest - -import mindspore.context as context -from mindspore import Tensor -from mindspore.ops import operations as P - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_roi_align_half(): - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - x = Tensor(np.array([[ - [[1, 2, 3, 4, 5, 6], - [7, 8, 9, 10, 11, 12], - [13, 14, 15, 16, 17, 18], - [19, 20, 21, 22, 23, 24], - [25, 26, 27, 28, 29, 30], - [31, 32, 33, 34, 35, 36]] - ]], np.float16)) - - rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float16)) - - # test case 1 - pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3 - roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0) - output = roi_align(x, rois) - print(output) - expect = [[[[1.2333, 2.1000, 3.3000, 4.5000], - [6.4333, 7.3000, 8.5000, 9.7000], - [13.6333, 14.5000, 15.7000, 16.9000], - [20.8333, 21.7000, 22.9000, 24.1000]]]] - np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=1) diff --git a/tests/st/ops/gpu/test_roi_align_op.py b/tests/st/ops/gpu/test_roi_align_op.py index 3c5dce22bf..e8900c68f4 100644 --- a/tests/st/ops/gpu/test_roi_align_op.py +++ b/tests/st/ops/gpu/test_roi_align_op.py @@ -25,61 +25,51 @@ from mindspore.ops import operations as P @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_roi_align(): - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - x = Tensor(np.array([[ - [[1, 2, 3, 4, 5, 6], - [7, 8, 9, 10, 11, 12], - [13, 14, 15, 16, 17, 18], - [19, 20, 21, 22, 23, 24], - [25, 26, 27, 28, 29, 30], - [31, 32, 33, 34, 35, 36]] - ]], np.float32)) + def roi_align_case(data_type): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + x = Tensor(np.array([[ + [[1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + [19, 20, 21, 22, 23, 24], + [25, 26, 27, 28, 29, 30], + [31, 32, 33, 34, 35, 36]] + ]], data_type)) - rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32)) + # test case 1 + rois = Tensor(np.array([[0, -2.0, -2.0, 21.0, 21.0]], data_type)) + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 + roi_align = P.ROIAlign(pooled_height, pooled_width, + spatial_scale, sample_num, 1) + output = roi_align(x, rois) + print(output) + expect = [[[[4.5, 6.5, 8.5], + [16.5, 18.5, 20.5], + [28.5, 30.5, 32.5]]]] + assert (output.asnumpy() == expect).all() - # test case 1 - pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 - roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0) - output = roi_align(x, rois) - print(output) - expect = [[[[2.75, 4.5, 6.5], - [13.25, 15., 17.], - [25.25, 27., 29.]]]] - assert (output.asnumpy() == expect).all() + # test case 2 + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], data_type)) + pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 + roi_align = P.ROIAlign(pooled_height, pooled_width, + spatial_scale, sample_num, 0) + output = roi_align(x, rois) + print(output) + expect = [[[[4.5, 6.5, 8.5], + [16.5, 18.5, 20.5], + [28.5, 30.5, 32.5]]]] + assert (output.asnumpy() == expect).all() - # test case 2 - pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3 - roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0) - output = roi_align(x, rois) - print(output) - expect = [[[[1.2333, 2.1000, 3.3000, 4.5000], - [6.4333, 7.3000, 8.5000, 9.7000], - [13.6333, 14.5000, 15.7000, 16.9000], - [20.8333, 21.7000, 22.9000, 24.1000]]]] - np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) + # test case 3 + pooled_height, pooled_width, spatial_scale, sample_num = 2, 2, 1.0, -1 + rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], data_type)) + roi_align = P.ROIAlign(pooled_height, pooled_width, + spatial_scale, sample_num, 0) + output = roi_align(x, rois) + print(output) + expect = [[[[6.295, 0.], + [0., 0.]]]] + np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=2) - # test case 3 - pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.3, 3 - rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0], - [0, 1.0, 0.0, 19.0, 18.0]], - np.float32)) - roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0) - output = roi_align(x, rois) - print(output) - expect = [[[[3.3333, 5.5000, 7.6667], - [16.3333, 18.5000, 20.6667], - [29.3333, 31.5000, 33.6667]]], - [[[4.5000, 6.3000, 8.1000], - [14.9000, 16.7000, 18.5000], - [25.7000, 27.5000, 29.3000]]]] - np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) - - # test case 4 - pooled_height, pooled_width, spatial_scale, sample_num = 2, 2, 1.0, -1 - rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32)) - roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num, 0) - output = roi_align(x, rois) - print(output) - expect = [[[[8.2222, 0.], - [0., 0.]]]] - np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) + roi_align_case(np.float32) + roi_align_case(np.float16)