You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

centernet.py 4.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. from mmdet.core import bbox2result
  4. from mmdet.models.builder import DETECTORS
  5. from ...core.utils import flip_tensor
  6. from .single_stage import SingleStageDetector
  # CenterNet detector, registered in the DETECTORS registry.
@DETECTORS.register_module()
class CenterNet(SingleStageDetector):
    """Implementation of CenterNet(Objects as Points)

    <https://arxiv.org/abs/1904.07850>.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        # All construction is delegated unchanged to SingleStageDetector.
        super(CenterNet, self).__init__(backbone, neck, bbox_head, train_cfg,
                                        test_cfg, pretrained, init_cfg)

    def merge_aug_results(self, aug_results, with_nms):
        """Merge augmented detection bboxes and labels.

        Args:
            aug_results (list[list[tuple[Tensor, Tensor]]]): One entry per
                flip pair. Each entry is the ``bbox_head.get_bboxes`` output
                for a single image: a list whose first element is
                ``(det_bboxes, det_labels)``.
            with_nms (bool): If True, run ``bbox_head._bboxes_nms`` on the
                concatenated detections before returning them.

        Returns:
            tuple: (out_bboxes, out_labels)
        """
        # Each get_bboxes call was made on a single image, so only element
        # [0] of every per-pair result is used.
        recovered_bboxes = [result[0][0] for result in aug_results]
        aug_labels = [result[0][1] for result in aug_results]
        bboxes = torch.cat(recovered_bboxes, dim=0).contiguous()
        labels = torch.cat(aug_labels).contiguous()
        if with_nms:
            out_bboxes, out_labels = self.bbox_head._bboxes_nms(
                bboxes, labels, self.bbox_head.test_cfg)
        else:
            out_bboxes, out_labels = bboxes, labels
        return out_bboxes, out_labels

    def aug_test(self, imgs, img_metas, rescale=True):
        """Augment testing of CenterNet. Aug test must have flipped image pair,
        and unlike CornerNet, it will perform an averaging operation on the
        feature map instead of detecting bbox.

        Args:
            imgs (list[Tensor]): Augmented images, consumed as consecutive
                (un-flipped, flipped) pairs.
            img_metas (list[list[dict]]): Meta information of each image, e.g.,
                image size, scaling factor, etc. Every entry must provide a
                ``'flip'`` flag; flipped entries also provide
                ``'flip_direction'``.
            rescale (bool): If True, return boxes in original image space.
                Default: True.

        Note:
            ``imgs`` must consist of (un-flipped, flipped) image pairs, so
            its length must be even.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        num_imgs = len(imgs)
        # An odd count would silently drop the last image, and fewer than two
        # images leaves nothing to pair with.
        assert num_imgs >= 2 and num_imgs % 2 == 0, (
            'aug test must have flipped image pair')
        aug_results = []
        for ind in range(0, num_imgs, 2):
            flip_ind = ind + 1
            # The averaging below flips the second image's features back and
            # decodes boxes with the first image's meta. Every pair must
            # therefore be (un-flipped, flipped), not just the first pair.
            assert (not img_metas[ind][0]['flip']
                    and img_metas[flip_ind][0]['flip']), (
                        'aug test must have flipped image pair')
            flip_direction = img_metas[flip_ind][0]['flip_direction']
            img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
            x = self.extract_feat(img_pair)
            center_heatmap_preds, wh_preds, offset_preds = self.bbox_head(x)
            assert len(center_heatmap_preds) == len(wh_preds) == len(
                offset_preds) == 1
            # Feature map averaging: batch index 0 holds the un-flipped
            # image and index 1 the flipped one, which is flipped back first.
            center_heatmap_preds[0] = (
                center_heatmap_preds[0][0:1] +
                flip_tensor(center_heatmap_preds[0][1:2], flip_direction)) / 2
            wh_preds[0] = (wh_preds[0][0:1] +
                           flip_tensor(wh_preds[0][1:2], flip_direction)) / 2
            # Offsets are taken from the un-flipped image only.
            bbox_list = self.bbox_head.get_bboxes(
                center_heatmap_preds,
                wh_preds, [offset_preds[0][0:1]],
                img_metas[ind],
                rescale=rescale,
                with_nms=False)
            aug_results.append(bbox_list)

        # NMS over the merged detections is enabled by the presence of an
        # 'nms_cfg' entry in the head's test_cfg.
        # NOTE(review): merge_aug_results passes the whole test_cfg (not
        # test_cfg['nms_cfg']) to bbox_head._bboxes_nms. Confirm which key
        # that helper actually reads.
        with_nms = self.bbox_head.test_cfg.get('nms_cfg', None) is not None
        bbox_list = [self.merge_aug_results(aug_results, with_nms)]
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results

No Description

Contributors (3)