You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

cornernet.py 3.7 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. from mmdet.core import bbox2result, bbox_mapping_back
  4. from ..builder import DETECTORS
  5. from .single_stage import SingleStageDetector
  6. @DETECTORS.register_module()
  7. class CornerNet(SingleStageDetector):
  8. """CornerNet.
  9. This detector is the implementation of the paper `CornerNet: Detecting
  10. Objects as Paired Keypoints <https://arxiv.org/abs/1808.01244>`_ .
  11. """
  12. def __init__(self,
  13. backbone,
  14. neck,
  15. bbox_head,
  16. train_cfg=None,
  17. test_cfg=None,
  18. pretrained=None,
  19. init_cfg=None):
  20. super(CornerNet, self).__init__(backbone, neck, bbox_head, train_cfg,
  21. test_cfg, pretrained, init_cfg)
  22. def merge_aug_results(self, aug_results, img_metas):
  23. """Merge augmented detection bboxes and score.
  24. Args:
  25. aug_results (list[list[Tensor]]): Det_bboxes and det_labels of each
  26. image.
  27. img_metas (list[list[dict]]): Meta information of each image, e.g.,
  28. image size, scaling factor, etc.
  29. Returns:
  30. tuple: (bboxes, labels)
  31. """
  32. recovered_bboxes, aug_labels = [], []
  33. for bboxes_labels, img_info in zip(aug_results, img_metas):
  34. img_shape = img_info[0]['img_shape'] # using shape before padding
  35. scale_factor = img_info[0]['scale_factor']
  36. flip = img_info[0]['flip']
  37. bboxes, labels = bboxes_labels
  38. bboxes, scores = bboxes[:, :4], bboxes[:, -1:]
  39. bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip)
  40. recovered_bboxes.append(torch.cat([bboxes, scores], dim=-1))
  41. aug_labels.append(labels)
  42. bboxes = torch.cat(recovered_bboxes, dim=0)
  43. labels = torch.cat(aug_labels)
  44. if bboxes.shape[0] > 0:
  45. out_bboxes, out_labels = self.bbox_head._bboxes_nms(
  46. bboxes, labels, self.bbox_head.test_cfg)
  47. else:
  48. out_bboxes, out_labels = bboxes, labels
  49. return out_bboxes, out_labels
  50. def aug_test(self, imgs, img_metas, rescale=False):
  51. """Augment testing of CornerNet.
  52. Args:
  53. imgs (list[Tensor]): Augmented images.
  54. img_metas (list[list[dict]]): Meta information of each image, e.g.,
  55. image size, scaling factor, etc.
  56. rescale (bool): If True, return boxes in original image space.
  57. Default: False.
  58. Note:
  59. ``imgs`` must including flipped image pairs.
  60. Returns:
  61. list[list[np.ndarray]]: BBox results of each image and classes.
  62. The outer list corresponds to each image. The inner list
  63. corresponds to each class.
  64. """
  65. img_inds = list(range(len(imgs)))
  66. assert img_metas[0][0]['flip'] + img_metas[1][0]['flip'], (
  67. 'aug test must have flipped image pair')
  68. aug_results = []
  69. for ind, flip_ind in zip(img_inds[0::2], img_inds[1::2]):
  70. img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
  71. x = self.extract_feat(img_pair)
  72. outs = self.bbox_head(x)
  73. bbox_list = self.bbox_head.get_bboxes(
  74. *outs, [img_metas[ind], img_metas[flip_ind]], False, False)
  75. aug_results.append(bbox_list[0])
  76. aug_results.append(bbox_list[1])
  77. bboxes, labels = self.merge_aug_results(aug_results, img_metas)
  78. bbox_results = bbox2result(bboxes, labels, self.bbox_head.num_classes)
  79. return [bbox_results]

No Description

Contributors (1)