Merge pull request !3201 from JonathanY/maintags/v0.6.0-beta
| @@ -0,0 +1,228 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "roi_align_impl.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| inline __device__ T gpu_atomic_add(const T val, T *address); | |||
| template <> | |||
| inline __device__ float gpu_atomic_add(const float val, float *address) { | |||
| return atomicAdd(address, val); | |||
| } | |||
| template <typename T> | |||
| __device__ void bilinear_interpolate(const int height, const int width, T y, T x, int *x_low, int *y_low, int *x_high, | |||
| int *y_high, T *w1, T *w2, T *w3, T *w4) { | |||
| // return 0 if out of map boundary | |||
| if (y <= static_cast<T>(-1.0) || y >= static_cast<T>(height) || x <= static_cast<T>(-1.0) || | |||
| x >= static_cast<T>(width)) { | |||
| *w1 = *w2 = *w3 = *w4 = 0; | |||
| *x_low = *x_high = *y_low = *y_high = -1; | |||
| return; | |||
| } | |||
| // low bounder is at least zero | |||
| y = y <= static_cast<T>(.0) ? static_cast<T>(.0) : y; | |||
| x = x <= static_cast<T>(.0) ? static_cast<T>(.0) : x; | |||
| // top left point | |||
| *y_low = static_cast<int>(y); | |||
| *x_low = static_cast<int>(x); | |||
| // bottom right point | |||
| if (*y_low >= height - 1) { | |||
| *y_high = *y_low = height - 1; | |||
| y = static_cast<T>(*y_low); | |||
| } else { | |||
| *y_high = *y_low + 1; | |||
| } | |||
| if (*x_low >= width - 1) { | |||
| *x_high = *x_low = width - 1; | |||
| x = static_cast<T>(*x_low); | |||
| } else { | |||
| *x_high = *x_low + 1; | |||
| } | |||
| // distance to nearest points | |||
| T lx, ly, hx, hy; | |||
| ly = y - static_cast<T>(*y_low), lx = x - static_cast<T>(*x_low); | |||
| hy = static_cast<T>(1.) - ly, hx = static_cast<T>(1.) - lx; | |||
| // weight is evaluated by the distance to point away. | |||
| // the closer to point home, the more weight, the farther to point away. | |||
| *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __device__ void bin_box(int thread_idx, const T *roi_boxes, int roi_cols, const T spatial_scale, const int sample_num, | |||
| int roi_end_mode, const int channels, const int height, const int width, | |||
| const int pooled_height, const int pooled_width, int *offset, int *n, int *c, int *ph, int *pw, | |||
| int *roi_bin_grid_h, int *roi_bin_grid_w, T *bin_size_h, T *bin_size_w, T *roi_start_h, | |||
| T *roi_start_w) { | |||
| // (n, c, ph, pw) is the base param of pooled map | |||
| *pw = thread_idx % pooled_width; | |||
| *ph = (thread_idx / pooled_width) % pooled_height; | |||
| *c = (thread_idx / pooled_width / pooled_height) % channels; | |||
| *n = thread_idx / pooled_width / pooled_height / channels; | |||
| // Roi has | |||
| // 1. 4 points, or | |||
| // 2. indicator + 4 points (1 + 4) | |||
| const T *roi_box = roi_boxes + (*n) * roi_cols; | |||
| int roi_batch_ind = 0; | |||
| if (roi_cols == 5) { | |||
| roi_batch_ind = roi_box[0]; | |||
| roi_box++; | |||
| } | |||
| // Scale and shift ROI | |||
| T roi_offset = roi_end_mode == 1 ? static_cast<T>(0.5) : static_cast<T>(.0); | |||
| *roi_start_w = roi_box[0] * spatial_scale - roi_offset; | |||
| *roi_start_h = roi_box[1] * spatial_scale - roi_offset; | |||
| T roi_end_w = roi_box[2] * spatial_scale - roi_offset; | |||
| T roi_end_h = roi_box[3] * spatial_scale - roi_offset; | |||
| // New ROI height/width | |||
| T roi_width = roi_end_w - (*roi_start_w); | |||
| T roi_height = roi_end_h - (*roi_start_h); | |||
| // ratio of roi / pooled | |||
| *bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); | |||
| *bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); | |||
| *offset = (roi_batch_ind * channels + (*c)) * height * width; | |||
| // grid (int) by Sample ratio if defined, otherwise by pooled H/W | |||
| *roi_bin_grid_h = (sample_num > 0) ? sample_num : static_cast<int>(roi_height / static_cast<T>(pooled_height)); | |||
| *roi_bin_grid_w = (sample_num > 0) ? sample_num : static_cast<int>(roi_width / static_cast<T>(pooled_width)); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void ROIAlignKernel(size_t size, const T *input, const T *roi_boxes, int roi_cols, T *out_data, | |||
| const T spatial_scale, const int sample_num, int roi_end_mode, const int channels, | |||
| const int height, const int width, const int pooled_height, const int pooled_width) { | |||
| for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; | |||
| thread_idx += blockDim.x * gridDim.x) { | |||
| int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w; | |||
| T bin_size_h, bin_size_w, roi_start_h, roi_start_w; | |||
| bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width, | |||
| pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h, | |||
| &bin_size_w, &roi_start_h, &roi_start_w); | |||
| // (n, c, ph, pw) is the base param of pooled map | |||
| const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w; | |||
| T accumulate_val = 0.; | |||
| for (int iy = 0; iy < roi_bin_grid_h; iy++) { | |||
| // Shift half point RIGHT for y / x, while previous scaled roi shift half point LEFT | |||
| const T y = roi_start_h + static_cast<T>(ph) * bin_size_h + | |||
| static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); | |||
| for (int ix = 0; ix < roi_bin_grid_w; ix++) { | |||
| const T x = roi_start_w + static_cast<T>(pw) * bin_size_w + | |||
| static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w); | |||
| // bilinear interpolate by shifted y / x | |||
| // calculate bilinear interpolation | |||
| int x_low, y_low, x_high, y_high; | |||
| T w1, w2, w3, w4; | |||
| bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4); | |||
| T v1 = input[y_low * width + x_low + offset]; | |||
| T v2 = input[y_low * width + x_high + offset]; | |||
| T v3 = input[y_high * width + x_low + offset]; | |||
| T v4 = input[y_high * width + x_high + offset]; | |||
| T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); | |||
| accumulate_val += val; | |||
| } | |||
| } | |||
| accumulate_val /= count_points_in_grid_cell; | |||
| out_data[thread_idx] = accumulate_val; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void ROIAlign(const T *x, const T *roi_boxes, int roi_rows, int roi_cols, T *out_data, const T spatial_scale, | |||
| const int sample_num, int roi_end_mode, const int channels, const int height, const int width, | |||
| const int pooled_height, const int pooled_width, cudaStream_t cuda_stream) { | |||
| size_t size = roi_rows * channels * pooled_height * pooled_width; | |||
| ROIAlignKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, x, roi_boxes, roi_cols, out_data, | |||
| spatial_scale, sample_num, roi_end_mode, channels, | |||
| height, width, pooled_height, pooled_width); | |||
| return; | |||
| } | |||
| template void ROIAlign<float>(const float *x, const float *roi_boxes, int roi_rows, int roi_cols, float *out_data, | |||
| const float spatial_scale, const int sample_num, int roi_end_mode, const int channels, | |||
| const int height, const int width, const int pooled_height, const int pooled_width, | |||
| cudaStream_t cuda_stream); | |||
| template void ROIAlign<half>(const half *x, const half *roi_boxes, int roi_rows, int roi_cols, half *out_data, | |||
| const half spatial_scale, const int sample_num, int roi_end_mode, const int channels, | |||
| const int height, const int width, const int pooled_height, const int pooled_width, | |||
| cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes, int roi_cols, T *dx, | |||
| const T spatial_scale, const int sample_num, int roi_end_mode, const int channels, | |||
| const int height, const int width, const int pooled_height, const int pooled_width) { | |||
| for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; | |||
| thread_idx += blockDim.x * gridDim.x) { | |||
| int offset, n, c, ph, pw, roi_bin_grid_h, roi_bin_grid_w; | |||
| T bin_size_h, bin_size_w, roi_start_h, roi_start_w; | |||
| bin_box(thread_idx, roi_boxes, roi_cols, spatial_scale, sample_num, roi_end_mode, channels, height, width, | |||
| pooled_height, pooled_width, &offset, &n, &c, &ph, &pw, &roi_bin_grid_h, &roi_bin_grid_w, &bin_size_h, | |||
| &bin_size_w, &roi_start_h, &roi_start_w); | |||
| // (n, c, ph, pw) is the base param of pooled map | |||
| const T count_points_in_grid_cell = roi_bin_grid_h * roi_bin_grid_w; | |||
| int top_offset = (n * channels + c) * pooled_height * pooled_width; | |||
| const T *offset_top_diff = dy + top_offset; | |||
| const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; | |||
| for (int iy = 0; iy < roi_bin_grid_h; iy++) { | |||
| // Shift half point RIGHT for y / x, while previous scaled roi shift half point LEFT | |||
| const T y = | |||
| roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); | |||
| for (int ix = 0; ix < roi_bin_grid_w; ix++) { | |||
| const T x = | |||
| roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w); | |||
| // bilinear interpolate by shifted y / x | |||
| // calculate bilinear interpolation | |||
| int x_low, y_low, x_high, y_high; | |||
| T w1, w2, w3, w4; | |||
| bilinear_interpolate(height, width, y, x, &x_low, &y_low, &x_high, &y_high, &w1, &w2, &w3, &w4); | |||
| T g1 = top_diff_this_bin * w1 / count_points_in_grid_cell; | |||
| T g2 = top_diff_this_bin * w2 / count_points_in_grid_cell; | |||
| T g3 = top_diff_this_bin * w3 / count_points_in_grid_cell; | |||
| T g4 = top_diff_this_bin * w4 / count_points_in_grid_cell; | |||
| if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { | |||
| gpu_atomic_add(static_cast<T>(g1), dx + offset + y_low * width + x_low); | |||
| gpu_atomic_add(static_cast<T>(g2), dx + offset + y_low * width + x_high); | |||
| gpu_atomic_add(static_cast<T>(g3), dx + offset + y_high * width + x_low); | |||
| gpu_atomic_add(static_cast<T>(g4), dx + offset + y_high * width + x_high); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,24 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_ | |||
| template <typename T> | |||
| void ROIAlign(const T *x, const T *roi_boxes, int roi_rows, int roi_cols, T *out_data, const T spatial_scale, | |||
| const int sample_num, int roi_end_mode, const int channels, const int height, const int width, | |||
| const int pooled_height, const int pooled_width, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ROI_ALIGN_IMPL_H_ | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/roi_align_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| ROIAlign, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ROIAlignGpuFwdKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| ROIAlign, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| ROIAlignGpuFwdKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,140 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class ROIAlignGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| ROIAlignGpuFwdKernel() : x_size_(0), rois_size_(0), output_size_(0) {} | |||
| ~ROIAlignGpuFwdKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| const T *x = GetDeviceAddress<T>(inputs, 0); | |||
| const T *rois = GetDeviceAddress<T>(inputs, 1); | |||
| T *out_data = GetDeviceAddress<T>(outputs, 0); | |||
| ROIAlign(x, rois, roi_rows_, roi_cols_, out_data, spatial_scale_, sample_num_, roi_end_mode_, channels_, height_, | |||
| width_, pooled_height_, pooled_width_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| // Get the number of input args | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but RioAlign needs 2 input."; | |||
| return false; | |||
| } | |||
| // Get the number of output args | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but RioAlign needs 1 output."; | |||
| return false; | |||
| } | |||
| // Get the input shapes | |||
| auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto rois_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| auto x_shape_size = x_shape.size(); | |||
| if (x_shape_size < 2) { | |||
| MS_LOG(ERROR) << "x shape szie is " << x_shape_size << ", but at lease 2D."; | |||
| return false; | |||
| } | |||
| // Get channels, height & width | |||
| channels_ = x_shape_size >= 3 ? x_shape[x_shape_size - 3] : 1; | |||
| height_ = x_shape[x_shape_size - 2]; | |||
| width_ = x_shape[x_shape_size - 1]; | |||
| x_shape_ = {channels_, height_, width_}; | |||
| x_size_ = channels_ * height_ * width_ * sizeof(T); | |||
| // Get rois rows and cols | |||
| roi_rows_ = rois_shape[0]; | |||
| roi_cols_ = rois_shape[1]; | |||
| rois_size_ = roi_rows_ * roi_cols_ * sizeof(T); | |||
| rois_shape_ = {roi_rows_, roi_cols_}; | |||
| // Get primitive args | |||
| pooled_height_ = GetAttr<int>(kernel_node, "pooled_height"); | |||
| pooled_width_ = GetAttr<int>(kernel_node, "pooled_width"); | |||
| spatial_scale_ = static_cast<T>(GetAttr<float>(kernel_node, "spatial_scale")); | |||
| sample_num_ = GetAttr<int>(kernel_node, "sample_num"); | |||
| roi_end_mode_ = GetAttr<int>(kernel_node, "roi_end_mode"); | |||
| // Get output_shape | |||
| output_shape_ = {roi_rows_, channels_, pooled_height_, pooled_width_}; | |||
| output_size_ = 1; | |||
| for (size_t i = 0; i < 4; i++) { | |||
| output_size_ *= output_shape_[i]; | |||
| } | |||
| output_size_ *= sizeof(T); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(x_size_); | |||
| input_size_list_.push_back(rois_size_); | |||
| output_size_list_.push_back(output_size_); | |||
| } | |||
| private: | |||
| int pooled_height_; | |||
| int pooled_width_; | |||
| T spatial_scale_; | |||
| int sample_num_; | |||
| int roi_end_mode_; | |||
| int roi_rows_; | |||
| int roi_cols_; | |||
| int channels_; | |||
| int height_; | |||
| int width_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| std::vector<int> x_shape_; | |||
| std::vector<int> rois_shape_; | |||
| std::vector<int> output_shape_; | |||
| size_t x_size_; | |||
| size_t rois_size_; | |||
| size_t output_size_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_ROI_ALIGN_GPU_KERNEL_H | |||
| @@ -0,0 +1,95 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_roi_align(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| x = Tensor(np.array([[ | |||
| [[1, 2, 3, 4, 5, 6], | |||
| [7, 8, 9, 10, 11, 12], | |||
| [13, 14, 15, 16, 17, 18], | |||
| [19, 20, 21, 22, 23, 24], | |||
| [25, 26, 27, 28, 29, 30], | |||
| [31, 32, 33, 34, 35, 36]] | |||
| ]], np.float32)) | |||
| rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32)) | |||
| # test case 1 | |||
| pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 | |||
| roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) | |||
| output = roi_align(x, rois) | |||
| print(output) | |||
| expect = [[[[2.75, 4.5, 6.5], | |||
| [13.25, 15., 17.], | |||
| [25.25, 27., 29.]]]] | |||
| assert (output.asnumpy() == expect).all() | |||
| # test case 1 | |||
| pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.25, 2 | |||
| roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) | |||
| output = roi_align(x, rois) | |||
| print(output) | |||
| expect = [[[[2.75, 4.5, 6.5], | |||
| [13.25, 15., 17.], | |||
| [25.25, 27., 29.]]]] | |||
| assert (output.asnumpy() == expect).all() | |||
| # test case 2 | |||
| pooled_height, pooled_width, spatial_scale, sample_num = 4, 4, 0.2, 3 | |||
| roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) | |||
| output = roi_align(x, rois) | |||
| print(output) | |||
| expect = [[[[1.2333, 2.1000, 3.3000, 4.5000], | |||
| [6.4333, 7.3000, 8.5000, 9.7000], | |||
| [13.6333, 14.5000, 15.7000, 16.9000], | |||
| [20.8333, 21.7000, 22.9000, 24.1000]]]] | |||
| np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) | |||
| # test case 3 | |||
| pooled_height, pooled_width, spatial_scale, sample_num = 3, 3, 0.3, 3 | |||
| rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0], | |||
| [0, 1.0, 0.0, 19.0, 18.0]], | |||
| np.float32)) | |||
| roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) | |||
| output = roi_align(x, rois) | |||
| print(output) | |||
| expect = [[[[3.3333, 5.5000, 7.6667], | |||
| [16.3333, 18.5000, 20.6667], | |||
| [29.3333, 31.5000, 33.6667]]], | |||
| [[[4.5000, 6.3000, 8.1000], | |||
| [14.9000, 16.7000, 18.5000], | |||
| [25.7000, 27.5000, 29.3000]]]] | |||
| np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) | |||
| # test case 4 | |||
| pooled_height, pooled_width, spatial_scale, sample_num = 2, 2, 1.0, -1 | |||
| rois = Tensor(np.array([[0, -2.0, -2.0, 22.0, 22.0]], np.float32)) | |||
| roi_align = P.ROIAlign(pooled_height, pooled_width, spatial_scale, sample_num) | |||
| output = roi_align(x, rois) | |||
| print(output) | |||
| expect = [[[[4.625, 0.], | |||
| [0., 0.]]]] | |||
| np.testing.assert_almost_equal(output.asnumpy(), expect, decimal=4) | |||