@@ -18,7 +18,7 @@
#include <limits>
#include <algorithm>
int RoundUpPower2M(int v) {
int NMSRoundUpPower2(int v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
@@ -30,12 +30,22 @@ int RoundUpPower2M(int v) {
}
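For context on the renamed helper: the body above (partially elided by the hunk) is the standard bit-smearing idiom for rounding a value up to the next power of two. A minimal host-side sketch of the full idiom, with a hypothetical `main` for illustration only:

```cpp
#include <cstdio>

// Standalone sketch of the next-power-of-two round-up used by NMSRoundUpPower2:
// smear the highest set bit into every lower position, then add one.
int RoundUpPow2Sketch(int v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  return ++v;
}

int main() {
  printf("%d -> %d\n", 20, RoundUpPow2Sketch(20));      // prints 20 -> 32
  printf("%d -> %d\n", 8000, RoundUpPow2Sketch(8000));  // prints 8000 -> 8192
  return 0;
}
```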
template <typename T>
__inline__ __device__ void SwapM(T *lhs, T *rhs) {
__inline__ __device__ void Swap(T *lhs, T *rhs) {
  T tmp = lhs[0];
  lhs[0] = rhs[0];
  rhs[0] = tmp;
}
template <typename T>
__global__ void PopulateOutput(T *data_in, T *data_out, int *index_buff, const int num, int box_size_) {
  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
    int correct_index = index_buff[(num - 1) - box_num];  // flip the array around
    for (int x = 0; x < 5; x++) {
      data_out[(box_num * box_size_) + x] = data_in[(correct_index * box_size_) + x];
    }
  }
}
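The new `PopulateOutput` kernel gathers boxes through the sorted index buffer in reverse, so an ascending sort by score comes out as a descending-score output. A host-side reference of the same gather (plain C++, names hypothetical; boxes assumed stored as x1, y1, x2, y2, score):

```cpp
#include <vector>

// Host-side equivalent of PopulateOutput: copy each box from data_in into
// data_out, reading the index buffer back-to-front so ascending score order
// becomes descending score order.
std::vector<float> PopulateOutputRef(const std::vector<float> &data_in,
                                     const std::vector<int> &index_buff,
                                     int num, int box_size) {
  std::vector<float> data_out(num * box_size);
  for (int box_num = 0; box_num < num; ++box_num) {
    int correct_index = index_buff[(num - 1) - box_num];  // flip the array around
    for (int x = 0; x < box_size; ++x) {
      data_out[box_num * box_size + x] = data_in[correct_index * box_size + x];
    }
  }
  return data_out;
}
```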
template <typename T>
__inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, int box_A_start, int box_B_start, T *area,
                                       float IOU_value) {
@@ -96,38 +106,29 @@ __global__ void FinalPass(const int num, const float IOU_value, T *output, T *ar
  }
}
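For readers following the IoU path: `IOUDecision` and `FinalPass` (bodies mostly elided by the hunks above) apply the usual intersection-over-union test, the same formula used by the manual Python reference removed further down in the test diff. A plain C++ sketch of that test (function and parameter names hypothetical):

```cpp
#include <algorithm>

// Sketch of the IoU suppression test between two boxes stored as
// (x1, y1, x2, y2, score). Returns true when box B should be kept, i.e. its
// overlap with box A does not exceed the threshold.
bool IouKeepsBoxB(const float *box_a, const float *box_b, float iou_threshold) {
  float area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]);
  float area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]);
  float width = std::max(std::min(box_a[2], box_b[2]) - std::max(box_a[0], box_b[0]), 0.0f);
  float height = std::max(std::min(box_a[3], box_b[3]) - std::max(box_a[1], box_b[1]), 0.0f);
  float intersection = width * height;
  return !(intersection / (area_a + area_b - intersection) > iou_threshold);
}
```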
template <typename T, typename S>
__global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const int ceil_power2, S *data_in,
                                        S *data_out, T *index_buff, S *data_buff, int box_size_) {
  // default: sort with share memory
  extern __shared__ T share_mem_NMS[];
  T *index_arr = share_mem_NMS;
  S *data_arr = reinterpret_cast<S *>(index_arr + ceil_power2);
  // sort with RAM
  if (index_buff != nullptr && data_buff != nullptr) {
    index_arr = index_buff + blockIdx.x * ceil_power2;
    data_arr = data_buff + blockIdx.x * ceil_power2;
  }
template <typename T>
__global__ void NMS_BitonicSortByKeyKernel(const int outer, const int inner, const int ceil_power2, T *input,
                                           T *data_buff, int *index_buff, int box_size_) {
  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
    index_arr[i] = (i < inner) ? T(i) : std::numeric_limits<T>::max();
    // populated directly from input data
    data_arr[i] = (i < inner) ? data_in[(blockIdx.x * inner + i) * box_size_ + 4] : std::numeric_limits<S>::max();
    data_buff[i] = (i < inner) ? input[(i * box_size_) + 4] : std::numeric_limits<T>::max();
    index_buff[i] = i;
  }
  __syncthreads();
  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
    for (size_t j = (i >> 1); j > 0; j >>= 1) {
      for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
        size_t tid_comp = tid ^ j;
        if (tid_comp > tid) {
          if ((tid & i) == 0) {
            if (data_arr[tid] > data_arr[tid_comp]) {
              SwapM(&index_arr[tid], &index_arr[tid_comp]);
              SwapM(&data_arr[tid], &data_arr[tid_comp]);
            if (data_buff[tid] > data_buff[tid_comp]) {
              Swap(&data_buff[tid], &data_buff[tid_comp]);
              Swap(&index_buff[tid], &index_buff[tid_comp]);
            }
          } else {
            if (data_arr[tid] < data_arr[tid_comp]) {
              SwapM(&index_arr[tid], &index_arr[tid_comp]);
              SwapM(&data_arr[tid], &data_arr[tid_comp]);
            if (data_buff[tid] < data_buff[tid_comp]) {
              Swap(&data_buff[tid], &data_buff[tid_comp]);
              Swap(&index_buff[tid], &index_buff[tid_comp]);
            }
          }
        }
@@ -135,36 +136,21 @@ __global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const
      __syncthreads();
    }
  }
  T correct_index;
  for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) {
    correct_index = index_arr[(inner - 1) - tid];
    // moved data from input to output, correct ordering using sorted index array
    for (auto i : {0, 1, 2, 3, 4}) {
      data_out[(blockIdx.x * inner + tid) * box_size_ + i] =
        data_in[(blockIdx.x * inner + correct_index) * box_size_ + i];
    }
  }
}
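The rewritten kernel keeps the classic bitonic sort-by-key loop structure: the stage length `i` doubles each pass, the compare-exchange stride `j` halves within a stage, and `tid & i` picks the compare direction so adjacent runs merge into one sorted sequence. A serial host-side sketch of the same pattern (keys padded to a power of two with +inf, indices carried along with every swap), useful as a reference when reasoning about the device loops; names are hypothetical:

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Serial reference for the bitonic sort-by-key pattern used above: sorts keys
// ascending and applies every swap to the index array as well. The key array
// length is assumed to already be a power of two (pad with +inf beforehand).
void BitonicSortByKeyRef(std::vector<float> *keys, std::vector<int> *idx) {
  size_t n = keys->size();
  for (size_t i = 2; i <= n; i <<= 1) {        // length of the bitonic sequences in this stage
    for (size_t j = i >> 1; j > 0; j >>= 1) {  // compare-exchange stride within the stage
      for (size_t tid = 0; tid < n; ++tid) {
        size_t tid_comp = tid ^ j;
        if (tid_comp <= tid) continue;    // each pair is handled once
        bool ascending = (tid & i) == 0;  // direction alternates per run
        bool out_of_order = ascending ? (*keys)[tid] > (*keys)[tid_comp]
                                      : (*keys)[tid] < (*keys)[tid_comp];
        if (out_of_order) {
          std::swap((*keys)[tid], (*keys)[tid_comp]);
          std::swap((*idx)[tid], (*idx)[tid_comp]);
        }
      }
    }
  }
}
```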
template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream) {
void CalPreprocess(const int num, int *sel_idx, T *area, T *input, T *output, int *index_buff, int box_size_,
                   cudaStream_t cuda_stream) {
  PopulateOutput<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(input, output, index_buff, num, box_size_);
  Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, area, output, box_size_);
}
template <typename T, typename S>
void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
                       int box_size_, cudaStream_t stream) {
  int ceil_power2 = RoundUpPower2M(inner);
  size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
  if (share_mem > SHARED_MEM_PER_BLOCK) {
    share_mem = 0;
  } else {
    data_buff = nullptr;
    index_buff = nullptr;
  }
  int thread = std::min(ceil_power2, GET_THREADS);
  BitonicSortByKeyKernelM<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, data_in, data_out,
                                                                index_buff, data_buff, box_size_);
template <typename T>
void CalSortInit(const int &num, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
                 cudaStream_t stream) {
  int ceil_p_2 = NMSRoundUpPower2(num);
  int thread = std::min(ceil_p_2, GET_THREADS);
  NMS_BitonicSortByKeyKernel<<<1, thread, 0, stream>>>(1, num, ceil_p_2, data_in, data_buff, index_buff, box_size_);
}
template <typename T>
@@ -180,11 +166,11 @@ void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool
  FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_);
}
template void CalPreprocess<float>(const int num, int *sel_idx, float *area, float *output, int box_size_,
                                   cudaStream_t cuda_stream);
template void CalPreprocess<float>(const int num, int *sel_idx, float *area, float *input, float *output,
                                   int *index_buff, int box_size_, cudaStream_t cuda_stream);
template void BitonicSortByKeyM(const int &outer, const int &inner, float *data_in, float *data_out, int *index_buff,
                                float *data_buff, int box_size_, cudaStream_t stream);
template void CalSortInit<float>(const int &inner, float *data_in, float *data_out, int *index_buff, float *data_buff,
                                 int box_size_, cudaStream_t stream);
template void CalNMSWithMask<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
                                    int box_size_, cudaStream_t cuda_stream);
@@ -20,18 +20,21 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream);
void CalPreprocess(const int num, int *sel_idx, T *area, T *input, T *output, int *index_buff, int box_size_,
                   cudaStream_t cuda_stream);
template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                    cudaStream_t cuda_stream);
template <typename T, typename S>
void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
                       int box_size_, cudaStream_t stream);
template <typename T>
void CalSortInit(const int &inner, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
                 cudaStream_t stream);
template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                  cudaStream_t cuda_stream);
int NMSRoundUpPower2(int v);
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
@@ -30,7 +30,8 @@ namespace kernel {
template <typename T>
class NMSWithMaskGpuFwdKernel : public GpuKernel {
 public:
  NMSWithMaskGpuFwdKernel() : num_input_(0), iou_value_(0.5), input_size_(0), output_size_(0), workspace_size_(0) {}
  NMSWithMaskGpuFwdKernel()
      : num_input_(0), iou_value_(0.5), input_size_(0), output_size_(0), workspace_size_(0), ceil_power_2(0) {}
  ~NMSWithMaskGpuFwdKernel() override = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -40,22 +41,24 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input = GetDeviceAddress<T>(inputs, 0);
    T *data_buff = GetDeviceAddress<T>(workspace, 0);  // sort buffer
    int *index_buff = GetDeviceAddress<int>(workspace, 1);
    T *area = GetDeviceAddress<T>(workspace, 2);  // store area values for all boxes
    T *area = GetDeviceAddress<T>(workspace, 0);  // store area values for all boxes
    T *data_buff = GetDeviceAddress<T>(workspace, 1);  // sort buffer
    int *index_buff = GetDeviceAddress<int>(workspace, 2);
    T *output = GetDeviceAddress<T>(outputs, 0);
    int *sel_idx = GetDeviceAddress<int>(outputs, 1);
    bool *sel_boxes = GetDeviceAddress<bool>(outputs, 2);
    BitonicSortByKeyM(num_input_, num_input_, input, output, index_buff, data_buff, box_size_,
                      reinterpret_cast<cudaStream_t>(stream_ptr));
    CalPreprocess(num_input_, sel_idx, area, output, box_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
    CalSortInit(num_input_, input, output, index_buff, data_buff, box_size_,
                reinterpret_cast<cudaStream_t>(stream_ptr));
    CalPreprocess(num_input_, sel_idx, area, input, output, index_buff, box_size_,
                  reinterpret_cast<cudaStream_t>(stream_ptr));
    CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_,
                 reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    iou_value_ = GetAttr<float>(kernel_node, "iou_threshold");
@@ -79,10 +82,13 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
    }
    num_input_ = input_shape[0];  // Get N value in [N,5] data
    ceil_power_2 = NMSRoundUpPower2(num_input_);
    input_size_ = num_input_ * sizeof(T) * box_size_;  // 5 values per bbox
    output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool));
    workspace_size_ = (2 * num_input_ * sizeof(T)) + (1 * num_input_ * sizeof(int));
    workspace_size_ = num_input_ * sizeof(int);
    workspace_size_ += ceil_power_2 * (sizeof(T) + sizeof(int));
    InitSizeLists();
    return true;
@@ -97,20 +103,20 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
    output_size_list_.push_back(num_input_ * sizeof(bool));
    // N sized workspace arrs
    workspace_size_list_.push_back(num_input_ * sizeof(T));
    workspace_size_list_.push_back(num_input_ * sizeof(int));
    workspace_size_list_.push_back(num_input_ * sizeof(T));
    workspace_size_list_.push_back(num_input_ * sizeof(T));    // area list
    workspace_size_list_.push_back(ceil_power_2 * sizeof(T));  // data buff
    workspace_size_list_.push_back(ceil_power_2 * sizeof(int));  // index buff
  }
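As a rough sanity check on the new workspace layout (an N-sized area list plus two power-of-two padded sort buffers), the sizes for the updated test's N = 8000 float boxes work out as below; the numbers are illustrative only:

```cpp
#include <cstdio>

// Illustrative arithmetic for the new workspace sizing with N = 8000 float boxes.
int main() {
  const int num_input = 8000;
  const int ceil_power_2 = 8192;                          // NMSRoundUpPower2(8000)
  size_t area_bytes = num_input * sizeof(float);          // 32000 bytes
  size_t data_buff_bytes = ceil_power_2 * sizeof(float);  // 32768 bytes
  size_t index_buff_bytes = ceil_power_2 * sizeof(int);   // 32768 bytes
  printf("total workspace: %zu bytes\n", area_bytes + data_buff_bytes + index_buff_bytes);
  return 0;
}
```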
 private:
  int num_input_;
  float iou_value_;
  static const int box_size_ = 5;  // pre_defined box width
  // int box_size__ = 5;  // current size of bboxes
  // default values
  size_t input_size_;
  size_t output_size_;
  size_t workspace_size_;
  size_t ceil_power_2;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
@@ -21,29 +21,6 @@ import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P
def manualNMS(bbox, overlap_val_iou):
    mask = [True] * len(bbox)
    for box_a_index, _ in enumerate(bbox):
        if not mask[box_a_index]:
            continue  # ignore if not in list
        box_a = bbox[box_a_index]  # select box for value extraction
        for box_b_index in range(box_a_index + 1, len(bbox)):
            if not mask[box_b_index]:
                continue  # ignore if not in list
            box_b = bbox[box_b_index]
            areaA = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
            areaB = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
            overlap_x1 = max(box_a[0], box_b[0])
            overlap_y1 = max(box_a[1], box_b[1])
            overlap_x2 = min(box_a[2], box_b[2])
            overlap_y2 = min(box_a[3], box_b[3])
            width = max((overlap_x2 - overlap_x1), 0)
            height = max((overlap_y2 - overlap_y1), 0)
            # generate IOU decision
            mask[box_b_index] = not (
                (width * height)/(areaA + areaB - (width * height))) > overlap_val_iou
    return mask
def runMSRun(op, bbox):
    inputs = Tensor(bbox, mindspore.float32)
@@ -60,10 +37,10 @@ def runMSRun(op, bbox):
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_check_order():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    nms_op = P.NMSWithMask(0.5)
    for _ in range(500):
        count = 20
    for _ in range(10):
        count = 8000
        box = np.random.randint(1, 100, size=(count, 4))
        box[:, 2] = box[:, 0] + box[:, 2]
        box[:, 3] = box[:, 1] + box[:, 3]
@@ -77,28 +54,6 @@
        ms_sorted_scores, np_sorted_scores)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_masl_check_result():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_count = 500
    for x in range(1, test_count+1):
        count = 20  # size of bbox lists
        nms_op = P.NMSWithMask(x * 0.002)  # will test full range b/w 0 and 1
        box = np.random.randint(1, 100, size=(count, 4))
        box[:, 2] = box[:, 0] + box[:, 2]
        box[:, 3] = box[:, 1] + box[:, 3]
        unsorted_scores = np.random.rand(count, 1)
        sorted_scores = np.sort(unsorted_scores, axis=0)[::-1]
        bbox = np.hstack((box, sorted_scores))
        bbox = Tensor(bbox, dtype=mindspore.float32)
        _, _, mask = nms_op(bbox)
        mask = mask.asnumpy()
        manual_mask = manualNMS(box, x * 0.002)
        np.testing.assert_array_equal(mask, np.array(manual_mask))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard