Merge pull request !3619 from danishnxt/GPU_Onetags/v0.7.0-beta
| @@ -0,0 +1,193 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, softwareg | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nms_with_mask_impl.cuh" | |||
| #include <limits> | |||
| #include <algorithm> | |||
| int RoundUpPower2M(int v) { | |||
| v--; | |||
| v |= v >> 1; | |||
| v |= v >> 2; | |||
| v |= v >> 4; | |||
| v |= v >> 8; | |||
| v |= v >> 16; | |||
| v++; | |||
| return v; | |||
| } | |||
| template <typename T> | |||
| __inline__ __device__ void SwapM(T *lhs, T *rhs) { | |||
| T tmp = lhs[0]; | |||
| lhs[0] = rhs[0]; | |||
| rhs[0] = tmp; | |||
| } | |||
| template <typename T> | |||
| __inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, int box_A_start, int box_B_start, T *area, | |||
| float IOU_value) { | |||
| T x_1 = max(output[box_A_start + 0], output[box_B_start + 0]); | |||
| T y_1 = max(output[box_A_start + 1], output[box_B_start + 1]); | |||
| T x_2 = min(output[box_A_start + 2], output[box_B_start + 2]); | |||
| T y_2 = min(output[box_A_start + 3], output[box_B_start + 3]); | |||
| T width = max(x_2 - x_1, T(0)); // in case of no overlap | |||
| T height = max(y_2 - y_1, T(0)); | |||
| T combined_area = area[box_A_ix] + area[box_B_ix]; | |||
| // return decision to keep or remove box | |||
| return !(((width * height) / (combined_area - (width * height))) > IOU_value); | |||
| } | |||
| template <typename T> | |||
| __global__ void Preprocess(const int num, int *sel_idx, T *area, T *output, int box_size_) { | |||
| for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) { | |||
| sel_idx[box_num] = box_num; | |||
| area[box_num] = (output[(box_num * box_size_) + 2] - output[(box_num * box_size_) + 0]) * | |||
| (output[(box_num * box_size_) + 3] - output[(box_num * box_size_) + 1]); | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void NMSWithMaskKernel(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, | |||
| int box_size_) { | |||
| for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) { | |||
| // represents highest score box in that GPU block | |||
| if (threadIdx.x == 0) { | |||
| sel_boxes[box_num] = true; | |||
| continue; | |||
| } | |||
| int box_start_index = box_num * box_size_; // start index adjustment | |||
| int block_max_box_num = ((blockIdx.x * blockDim.x) + 0); | |||
| int block_max_box_start_index = block_max_box_num * box_size_; // start index adjustment | |||
| sel_boxes[box_num] = | |||
| IOUDecision(output, box_num, block_max_box_num, block_max_box_start_index, box_start_index, area, | |||
| IOU_value); // update mask | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void FinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_) { | |||
| int box_i, box_j; // access all shared mem meta data with these | |||
| int box_i_start_index, box_j_start_index; // actual input data indexing | |||
| for (int i = 0; i < num - 1; i++) { | |||
| box_i = i; | |||
| box_i_start_index = box_i * box_size_; // adjust starting index | |||
| if (sel_boxes[box_i]) { | |||
| for (int j = i + 1; j < num; j++) { | |||
| box_j = j; | |||
| box_j_start_index = box_j * box_size_; | |||
| if (sel_boxes[box_j]) { | |||
| sel_boxes[box_j] = IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename T, typename S> | |||
| __global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const int ceil_power2, S *data_in, | |||
| S *data_out, T *index_buff, S *data_buff, int box_size_) { | |||
| // default: sort with share memory | |||
| extern __shared__ T share_mem_NMS[]; | |||
| T *index_arr = share_mem_NMS; | |||
| S *data_arr = reinterpret_cast<S *>(index_arr + ceil_power2); | |||
| // sort with RAM | |||
| if (index_buff != nullptr && data_buff != nullptr) { | |||
| index_arr = index_buff + blockIdx.x * ceil_power2; | |||
| data_arr = data_buff + blockIdx.x * ceil_power2; | |||
| } | |||
| for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) { | |||
| index_arr[i] = (i < inner) ? T(i) : std::numeric_limits<T>::max(); | |||
| // populated directly from input data | |||
| data_arr[i] = (i < inner) ? data_in[(blockIdx.x * inner + i) * box_size_ + 4] : std::numeric_limits<S>::max(); | |||
| } | |||
| __syncthreads(); | |||
| for (size_t i = 2; i <= ceil_power2; i <<= 1) { | |||
| for (size_t j = (i >> 1); j > 0; j >>= 1) { | |||
| for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) { | |||
| size_t tid_comp = tid ^ j; | |||
| if (tid_comp > tid) { | |||
| if ((tid & i) == 0) { | |||
| if (data_arr[tid] > data_arr[tid_comp]) { | |||
| SwapM(&index_arr[tid], &index_arr[tid_comp]); | |||
| SwapM(&data_arr[tid], &data_arr[tid_comp]); | |||
| } | |||
| } else { | |||
| if (data_arr[tid] < data_arr[tid_comp]) { | |||
| SwapM(&index_arr[tid], &index_arr[tid_comp]); | |||
| SwapM(&data_arr[tid], &data_arr[tid_comp]); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| __syncthreads(); | |||
| } | |||
| } | |||
| T correct_index; | |||
| for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) { | |||
| correct_index = index_arr[(inner - 1) - tid]; | |||
| // moved data from input to output, correct ordering using sorted index array | |||
| for (auto i : {0, 1, 2, 3, 4}) { | |||
| data_out[(blockIdx.x * inner + tid) * box_size_ + i] = | |||
| data_in[(blockIdx.x * inner + correct_index) * box_size_ + i]; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream) { | |||
| Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, area, output, box_size_); | |||
| } | |||
| template <typename T, typename S> | |||
| void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff, | |||
| int box_size_, cudaStream_t stream) { | |||
| int ceil_power2 = RoundUpPower2M(inner); | |||
| size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S)); | |||
| if (share_mem > SHARED_MEM_PER_BLOCK) { | |||
| share_mem = 0; | |||
| } else { | |||
| data_buff = nullptr; | |||
| index_buff = nullptr; | |||
| } | |||
| int thread = std::min(ceil_power2, GET_THREADS); | |||
| BitonicSortByKeyKernelM<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, data_in, data_out, | |||
| index_buff, data_buff, box_size_); | |||
| } | |||
| template <typename T> | |||
| void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, | |||
| cudaStream_t cuda_stream) { | |||
| NMSWithMaskKernel<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, | |||
| box_size_); | |||
| } | |||
| template <typename T> | |||
| void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, | |||
| cudaStream_t cuda_stream) { | |||
| FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_); | |||
| } | |||
| template void CalPreprocess<float>(const int num, int *sel_idx, float *area, float *output, int box_size_, | |||
| cudaStream_t cuda_stream); | |||
| template void BitonicSortByKeyM(const int &outer, const int &inner, float *data_in, float *data_out, int *index_buff, | |||
| float *data_buff, int box_size_, cudaStream_t stream); | |||
| template void CalNMSWithMask<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes, | |||
| int box_size_, cudaStream_t cuda_stream); | |||
| template void CalFinalPass<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes, | |||
| int box_size_, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_ | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, | |||
| cudaStream_t cuda_stream); | |||
| template <typename T, typename S> | |||
| void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff, | |||
| int box_size_, cudaStream_t stream); | |||
| template <typename T> | |||
| void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_ | |||
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(NMSWithMask, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeBool), | |||
| NMSWithMaskGpuFwdKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,121 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <iostream> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/kernel_constants.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class NMSWithMaskGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| NMSWithMaskGpuFwdKernel() : num_input_(0), iou_value_(0.5), input_size_(0), output_size_(0), workspace_size_(0) {} | |||
| ~NMSWithMaskGpuFwdKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| T *data_buff = GetDeviceAddress<T>(workspace, 0); // sort buffer | |||
| int *index_buff = GetDeviceAddress<int>(workspace, 1); | |||
| T *area = GetDeviceAddress<T>(workspace, 2); // store area values for all boxes | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| int *sel_idx = GetDeviceAddress<int>(outputs, 1); | |||
| bool *sel_boxes = GetDeviceAddress<bool>(outputs, 2); | |||
| BitonicSortByKeyM(num_input_, num_input_, input, output, index_buff, data_buff, box_size_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalPreprocess(num_input_, sel_idx, area, output, box_size_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| iou_value_ = GetAttr<float>(kernel_node, "iou_threshold"); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but NMSWithMask needs 1 input."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 3) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but NMSWithMask needs 3 output."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (CHECK_NULL_INPUT(input_shape)) { | |||
| MS_LOG(WARNING) << "NMSWithMask input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| num_input_ = input_shape[0]; // Get N value in [N,5] data | |||
| input_size_ = num_input_ * sizeof(T) * box_size_; // 5 values per bbox | |||
| output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool)); | |||
| workspace_size_ = (2 * num_input_ * sizeof(T)) + (1 * num_input_ * sizeof(int)); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| // N sized input/output data | |||
| input_size_list_.push_back(num_input_ * sizeof(T) * box_size_); | |||
| output_size_list_.push_back(num_input_ * sizeof(T) * box_size_); | |||
| output_size_list_.push_back(num_input_ * sizeof(int)); | |||
| output_size_list_.push_back(num_input_ * sizeof(bool)); | |||
| // N sized workspace arrs | |||
| workspace_size_list_.push_back(num_input_ * sizeof(T)); | |||
| workspace_size_list_.push_back(num_input_ * sizeof(int)); | |||
| workspace_size_list_.push_back(num_input_ * sizeof(T)); | |||
| } | |||
| private: | |||
| int num_input_; | |||
| float iou_value_; | |||
| static const int box_size_ = 5; // pre_defined box width | |||
| // int box_size__ = 5; // current size of bboxes | |||
| // default values | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_ | |||
| @@ -0,0 +1,154 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| def manualNMS(bbox, overlap_val_iou): | |||
| mask = [True] * len(bbox) | |||
| for box_a_index, _ in enumerate(bbox): | |||
| if not mask[box_a_index]: | |||
| continue # ignore if not in list | |||
| box_a = bbox[box_a_index] # select box for value extraction | |||
| for box_b_index in range(box_a_index + 1, len(bbox)): | |||
| if not mask[box_b_index]: | |||
| continue # ignore if not in list | |||
| box_b = bbox[box_b_index] | |||
| areaA = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]) | |||
| areaB = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]) | |||
| overlap_x1 = max(box_a[0], box_b[0]) | |||
| overlap_y1 = max(box_a[1], box_b[1]) | |||
| overlap_x2 = min(box_a[2], box_b[2]) | |||
| overlap_y2 = min(box_a[3], box_b[3]) | |||
| width = max((overlap_x2 - overlap_x1), 0) | |||
| height = max((overlap_y2 - overlap_y1), 0) | |||
| # generate IOU decision | |||
| mask[box_b_index] = not ( | |||
| (width * height)/(areaA + areaB - (width * height))) > overlap_val_iou | |||
| return mask | |||
| def runMSRun(op, bbox): | |||
| inputs = Tensor(bbox, mindspore.float32) | |||
| box, _, mask = op(inputs) | |||
| box = box.asnumpy() | |||
| mask = mask.asnumpy() | |||
| sel_idx = np.where(mask) | |||
| sel_rows = box[sel_idx][:, 0:4] | |||
| sel_score = box[sel_idx][:, -1] | |||
| return sel_rows, sel_score | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_nms_with_mask_check_order(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| nms_op = P.NMSWithMask(0.5) | |||
| for _ in range(500): | |||
| count = 20 | |||
| box = np.random.randint(1, 100, size=(count, 4)) | |||
| box[:, 2] = box[:, 0] + box[:, 2] | |||
| box[:, 3] = box[:, 1] + box[:, 3] | |||
| unsorted_scores = np.random.rand(count, 1) | |||
| bbox = np.hstack((box, unsorted_scores)) | |||
| bbox = Tensor(bbox, dtype=mindspore.float32) | |||
| prop, _, _ = nms_op(bbox) | |||
| ms_sorted_scores = (prop.asnumpy()[:, -1]) # select just scores | |||
| np_sorted_scores = (np.sort(unsorted_scores, axis=0)[::-1][:, 0]) # sort manually | |||
| np.testing.assert_array_almost_equal( | |||
| ms_sorted_scores, np_sorted_scores) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_nms_with_masl_check_result(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| test_count = 500 | |||
| for x in range(1, test_count+1): | |||
| count = 20 # size of bbox lists | |||
| nms_op = P.NMSWithMask(x * 0.002) # will test full range b/w 0 and 1 | |||
| box = np.random.randint(1, 100, size=(count, 4)) | |||
| box[:, 2] = box[:, 0] + box[:, 2] | |||
| box[:, 3] = box[:, 1] + box[:, 3] | |||
| unsorted_scores = np.random.rand(count, 1) | |||
| sorted_scores = np.sort(unsorted_scores, axis=0)[::-1] | |||
| bbox = np.hstack((box, sorted_scores)) | |||
| bbox = Tensor(bbox, dtype=mindspore.float32) | |||
| _, _, mask = nms_op(bbox) | |||
| mask = mask.asnumpy() | |||
| manual_mask = manualNMS(box, x * 0.002) | |||
| np.testing.assert_array_equal(mask, np.array(manual_mask)) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_nms_with_mask_edge_case_1(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| # CASE 1 - FULL OVERLAP BOXES - Every box is duplicated and has a different score | |||
| nms_op1 = P.NMSWithMask(0.3) | |||
| bbox1 = [[12, 4, 33, 17, 0.6], [20, 11, 38, 23, 0.1], [20, 10, 45, 26, 0.9], [15, 17, 35, 38, 0.5], | |||
| [10, 20, 30, 40, 0.4], [35, 35, 89, 90, 0.8], [12, 4, 33, 17, 0.3], [20, 11, 38, 23, 0.2], | |||
| [20, 10, 45, 26, 0.1], [15, 17, 35, 38, 0.8], [10, 20, 30, 40, 0.41], [35, 35, 89, 90, 0.82]] | |||
| expected_bbox = np.array([[20., 10., 45., 26.], | |||
| [35., 35., 89., 90.], | |||
| [15., 17., 35., 38.], | |||
| [12., 4., 33., 17.]]) | |||
| expected_score = np.array([0.9, 0.82, 0.8, 0.6]) | |||
| sel_rows, sel_score = runMSRun(nms_op1, bbox1) | |||
| np.testing.assert_almost_equal(sel_rows, expected_bbox) | |||
| np.testing.assert_almost_equal(sel_score, expected_score) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_nms_with_mask_edge_case_2(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| # CASE 2 - 0 value boxes - with valid scores | |||
| nms_op2 = P.NMSWithMask(0.5) | |||
| bbox2 = [[0, 0, 0, 0, 0.6], [0, 0, 0, 0, 0.1]] | |||
| expected_bbox = np.array([[0., 0., 0., 0.], | |||
| [0., 0., 0., 0.]]) | |||
| expected_score = np.array([0.6, 0.1]) | |||
| sel_rows, sel_score = runMSRun(nms_op2, bbox2) | |||
| np.testing.assert_almost_equal(sel_rows, expected_bbox) | |||
| np.testing.assert_almost_equal(sel_score, expected_score) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_nms_with_mask_edge_case_3(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| # CASE 3 - x2/x1 and y2/y1 sequence out of place | |||
| nms_op3 = P.NMSWithMask(0.7) | |||
| bbox3 = [[70, 70, 45, 75, 0.6], [30, 33, 43, 29, 0.1]] | |||
| expected_bbox = np.array([[70., 70., 45., 75.], | |||
| [30., 33., 43., 29.]]) | |||
| expected_score = np.array([0.6, 0.1]) | |||
| sel_rows, sel_score = runMSRun(nms_op3, bbox3) | |||
| np.testing.assert_almost_equal(sel_rows, expected_bbox) | |||
| np.testing.assert_almost_equal(sel_score, expected_score) | |||