You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

base_dense_head.py 23 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. import torch
  4. from mmcv.ops import batched_nms
  5. from mmcv.runner import BaseModule, force_fp32
  6. from mmdet.core.utils import filter_scores_and_topk, select_single_mlvl
  7. class BaseDenseHead(BaseModule, metaclass=ABCMeta):
  8. """Base class for DenseHeads."""
  9. def __init__(self, init_cfg=None):
  10. super(BaseDenseHead, self).__init__(init_cfg)
  11. @abstractmethod
  12. def loss(self, **kwargs):
  13. """Compute losses of the head."""
  14. pass
  15. @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
  16. def get_bboxes(self,
  17. cls_scores,
  18. bbox_preds,
  19. score_factors=None,
  20. img_metas=None,
  21. cfg=None,
  22. rescale=False,
  23. with_nms=True,
  24. **kwargs):
  25. """Transform network outputs of a batch into bbox results.
  26. Note: When score_factors is not None, the cls_scores are
  27. usually multiplied by it then obtain the real score used in NMS,
  28. such as CenterNess in FCOS, IoU branch in ATSS.
  29. Args:
  30. cls_scores (list[Tensor]): Classification scores for all
  31. scale levels, each is a 4D-tensor, has shape
  32. (batch_size, num_priors * num_classes, H, W).
  33. bbox_preds (list[Tensor]): Box energies / deltas for all
  34. scale levels, each is a 4D-tensor, has shape
  35. (batch_size, num_priors * 4, H, W).
  36. score_factors (list[Tensor], Optional): Score factor for
  37. all scale level, each is a 4D-tensor, has shape
  38. (batch_size, num_priors * 1, H, W). Default None.
  39. img_metas (list[dict], Optional): Image meta info. Default None.
  40. cfg (mmcv.Config, Optional): Test / postprocessing configuration,
  41. if None, test_cfg would be used. Default None.
  42. rescale (bool): If True, return boxes in original image space.
  43. Default False.
  44. with_nms (bool): If True, do nms before return boxes.
  45. Default True.
  46. Returns:
  47. list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple.
  48. The first item is an (n, 5) tensor, where the first 4 columns
  49. are bounding box positions (tl_x, tl_y, br_x, br_y) and the
  50. 5-th column is a score between 0 and 1. The second item is a
  51. (n,) tensor where each item is the predicted class label of
  52. the corresponding box.
  53. """
  54. assert len(cls_scores) == len(bbox_preds)
  55. if score_factors is None:
  56. # e.g. Retina, FreeAnchor, Foveabox, etc.
  57. with_score_factors = False
  58. else:
  59. # e.g. FCOS, PAA, ATSS, AutoAssign, etc.
  60. with_score_factors = True
  61. assert len(cls_scores) == len(score_factors)
  62. num_levels = len(cls_scores)
  63. featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
  64. mlvl_priors = self.prior_generator.grid_priors(
  65. featmap_sizes,
  66. dtype=cls_scores[0].device,
  67. device=cls_scores[0].device)
  68. result_list = []
  69. for img_id in range(len(img_metas)):
  70. img_meta = img_metas[img_id]
  71. cls_score_list = select_single_mlvl(cls_scores, img_id)
  72. bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
  73. if with_score_factors:
  74. score_factor_list = select_single_mlvl(score_factors, img_id)
  75. else:
  76. score_factor_list = [None for _ in range(num_levels)]
  77. results = self._get_bboxes_single(cls_score_list, bbox_pred_list,
  78. score_factor_list, mlvl_priors,
  79. img_meta, cfg, rescale, with_nms,
  80. **kwargs)
  81. result_list.append(results)
  82. return result_list
  83. def _get_bboxes_single(self,
  84. cls_score_list,
  85. bbox_pred_list,
  86. score_factor_list,
  87. mlvl_priors,
  88. img_meta,
  89. cfg,
  90. rescale=False,
  91. with_nms=True,
  92. **kwargs):
  93. """Transform outputs of a single image into bbox predictions.
  94. Args:
  95. cls_score_list (list[Tensor]): Box scores from all scale
  96. levels of a single image, each item has shape
  97. (num_priors * num_classes, H, W).
  98. bbox_pred_list (list[Tensor]): Box energies / deltas from
  99. all scale levels of a single image, each item has shape
  100. (num_priors * 4, H, W).
  101. score_factor_list (list[Tensor]): Score factor from all scale
  102. levels of a single image, each item has shape
  103. (num_priors * 1, H, W).
  104. mlvl_priors (list[Tensor]): Each element in the list is
  105. the priors of a single level in feature pyramid. In all
  106. anchor-based methods, it has shape (num_priors, 4). In
  107. all anchor-free methods, it has shape (num_priors, 2)
  108. when `with_stride=True`, otherwise it still has shape
  109. (num_priors, 4).
  110. img_meta (dict): Image meta info.
  111. cfg (mmcv.Config): Test / postprocessing configuration,
  112. if None, test_cfg would be used.
  113. rescale (bool): If True, return boxes in original image space.
  114. Default: False.
  115. with_nms (bool): If True, do nms before return boxes.
  116. Default: True.
  117. Returns:
  118. tuple[Tensor]: Results of detected bboxes and labels. If with_nms
  119. is False and mlvl_score_factor is None, return mlvl_bboxes and
  120. mlvl_scores, else return mlvl_bboxes, mlvl_scores and
  121. mlvl_score_factor. Usually with_nms is False is used for aug
  122. test. If with_nms is True, then return the following format
  123. - det_bboxes (Tensor): Predicted bboxes with shape \
  124. [num_bboxes, 5], where the first 4 columns are bounding \
  125. box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
  126. column are scores between 0 and 1.
  127. - det_labels (Tensor): Predicted labels of the corresponding \
  128. box with shape [num_bboxes].
  129. """
  130. if score_factor_list[0] is None:
  131. # e.g. Retina, FreeAnchor, etc.
  132. with_score_factors = False
  133. else:
  134. # e.g. FCOS, PAA, ATSS, etc.
  135. with_score_factors = True
  136. cfg = self.test_cfg if cfg is None else cfg
  137. img_shape = img_meta['img_shape']
  138. nms_pre = cfg.get('nms_pre', -1)
  139. mlvl_bboxes = []
  140. mlvl_scores = []
  141. mlvl_labels = []
  142. if with_score_factors:
  143. mlvl_score_factors = []
  144. else:
  145. mlvl_score_factors = None
  146. for level_idx, (cls_score, bbox_pred, score_factor, priors) in \
  147. enumerate(zip(cls_score_list, bbox_pred_list,
  148. score_factor_list, mlvl_priors)):
  149. assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
  150. bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
  151. if with_score_factors:
  152. score_factor = score_factor.permute(1, 2,
  153. 0).reshape(-1).sigmoid()
  154. cls_score = cls_score.permute(1, 2,
  155. 0).reshape(-1, self.cls_out_channels)
  156. if self.use_sigmoid_cls:
  157. scores = cls_score.sigmoid()
  158. else:
  159. # remind that we set FG labels to [0, num_class-1]
  160. # since mmdet v2.0
  161. # BG cat_id: num_class
  162. scores = cls_score.softmax(-1)
  163. scores = cls_score.softmax(-1)[:, :-1]
  164. # After https://github.com/open-mmlab/mmdetection/pull/6268/,
  165. # this operation keeps fewer bboxes under the same `nms_pre`.
  166. # There is no difference in performance for most models. If you
  167. # find a slight drop in performance, you can set a larger
  168. # `nms_pre` than before.
  169. results = filter_scores_and_topk(
  170. scores, cfg.score_thr, nms_pre,
  171. dict(bbox_pred=bbox_pred, priors=priors))
  172. scores, labels, keep_idxs, filtered_results = results
  173. bbox_pred = filtered_results['bbox_pred']
  174. priors = filtered_results['priors']
  175. if with_score_factors:
  176. score_factor = score_factor[keep_idxs]
  177. bboxes = self.bbox_coder.decode(
  178. priors, bbox_pred, max_shape=img_shape)
  179. mlvl_bboxes.append(bboxes)
  180. mlvl_scores.append(scores)
  181. mlvl_labels.append(labels)
  182. if with_score_factors:
  183. mlvl_score_factors.append(score_factor)
  184. return self._bbox_post_process(mlvl_scores, mlvl_labels, mlvl_bboxes,
  185. img_meta['scale_factor'], cfg, rescale,
  186. with_nms, mlvl_score_factors, **kwargs)
  187. def _bbox_post_process(self,
  188. mlvl_scores,
  189. mlvl_labels,
  190. mlvl_bboxes,
  191. scale_factor,
  192. cfg,
  193. rescale=False,
  194. with_nms=True,
  195. mlvl_score_factors=None,
  196. **kwargs):
  197. """bbox post-processing method.
  198. The boxes would be rescaled to the original image scale and do
  199. the nms operation. Usually with_nms is False is used for aug test.
  200. Args:
  201. mlvl_scores (list[Tensor]): Box scores from all scale
  202. levels of a single image, each item has shape
  203. (num_bboxes, ).
  204. mlvl_labels (list[Tensor]): Box class labels from all scale
  205. levels of a single image, each item has shape
  206. (num_bboxes, ).
  207. mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale
  208. levels of a single image, each item has shape (num_bboxes, 4).
  209. scale_factor (ndarray, optional): Scale factor of the image arange
  210. as (w_scale, h_scale, w_scale, h_scale).
  211. cfg (mmcv.Config): Test / postprocessing configuration,
  212. if None, test_cfg would be used.
  213. rescale (bool): If True, return boxes in original image space.
  214. Default: False.
  215. with_nms (bool): If True, do nms before return boxes.
  216. Default: True.
  217. mlvl_score_factors (list[Tensor], optional): Score factor from
  218. all scale levels of a single image, each item has shape
  219. (num_bboxes, ). Default: None.
  220. Returns:
  221. tuple[Tensor]: Results of detected bboxes and labels. If with_nms
  222. is False and mlvl_score_factor is None, return mlvl_bboxes and
  223. mlvl_scores, else return mlvl_bboxes, mlvl_scores and
  224. mlvl_score_factor. Usually with_nms is False is used for aug
  225. test. If with_nms is True, then return the following format
  226. - det_bboxes (Tensor): Predicted bboxes with shape \
  227. [num_bboxes, 5], where the first 4 columns are bounding \
  228. box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
  229. column are scores between 0 and 1.
  230. - det_labels (Tensor): Predicted labels of the corresponding \
  231. box with shape [num_bboxes].
  232. """
  233. assert len(mlvl_scores) == len(mlvl_bboxes) == len(mlvl_labels)
  234. mlvl_bboxes = torch.cat(mlvl_bboxes)
  235. if rescale:
  236. mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
  237. mlvl_scores = torch.cat(mlvl_scores)
  238. mlvl_labels = torch.cat(mlvl_labels)
  239. if mlvl_score_factors is not None:
  240. # TODO: Add sqrt operation in order to be consistent with
  241. # the paper.
  242. mlvl_score_factors = torch.cat(mlvl_score_factors)
  243. mlvl_scores = mlvl_scores * mlvl_score_factors
  244. if with_nms:
  245. if mlvl_bboxes.numel() == 0:
  246. det_bboxes = torch.cat([mlvl_bboxes, mlvl_scores[:, None]], -1)
  247. return det_bboxes, mlvl_labels
  248. det_bboxes, keep_idxs = batched_nms(mlvl_bboxes, mlvl_scores,
  249. mlvl_labels, cfg.nms)
  250. det_bboxes = det_bboxes[:cfg.max_per_img]
  251. det_labels = mlvl_labels[keep_idxs][:cfg.max_per_img]
  252. return det_bboxes, det_labels
  253. else:
  254. return mlvl_bboxes, mlvl_scores, mlvl_labels
  255. def forward_train(self,
  256. x,
  257. img_metas,
  258. gt_bboxes,
  259. gt_labels=None,
  260. gt_bboxes_ignore=None,
  261. proposal_cfg=None,
  262. **kwargs):
  263. """
  264. Args:
  265. x (list[Tensor]): Features from FPN.
  266. img_metas (list[dict]): Meta information of each image, e.g.,
  267. image size, scaling factor, etc.
  268. gt_bboxes (Tensor): Ground truth bboxes of the image,
  269. shape (num_gts, 4).
  270. gt_labels (Tensor): Ground truth labels of each box,
  271. shape (num_gts,).
  272. gt_bboxes_ignore (Tensor): Ground truth bboxes to be
  273. ignored, shape (num_ignored_gts, 4).
  274. proposal_cfg (mmcv.Config): Test / postprocessing configuration,
  275. if None, test_cfg would be used
  276. Returns:
  277. tuple:
  278. losses: (dict[str, Tensor]): A dictionary of loss components.
  279. proposal_list (list[Tensor]): Proposals of each image.
  280. """
  281. outs = self(x)
  282. if gt_labels is None:
  283. loss_inputs = outs + (gt_bboxes, img_metas)
  284. else:
  285. loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
  286. losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
  287. if proposal_cfg is None:
  288. return losses
  289. else:
  290. proposal_list = self.get_bboxes(
  291. *outs, img_metas=img_metas, cfg=proposal_cfg)
  292. return losses, proposal_list
  293. def simple_test(self, feats, img_metas, rescale=False):
  294. """Test function without test-time augmentation.
  295. Args:
  296. feats (tuple[torch.Tensor]): Multi-level features from the
  297. upstream network, each is a 4D-tensor.
  298. img_metas (list[dict]): List of image information.
  299. rescale (bool, optional): Whether to rescale the results.
  300. Defaults to False.
  301. Returns:
  302. list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
  303. The first item is ``bboxes`` with shape (n, 5),
  304. where 5 represent (tl_x, tl_y, br_x, br_y, score).
  305. The shape of the second tensor in the tuple is ``labels``
  306. with shape (n, ).
  307. """
  308. return self.simple_test_bboxes(feats, img_metas, rescale=rescale)
  309. @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
  310. def onnx_export(self,
  311. cls_scores,
  312. bbox_preds,
  313. score_factors=None,
  314. img_metas=None,
  315. with_nms=True):
  316. """Transform network output for a batch into bbox predictions.
  317. Args:
  318. cls_scores (list[Tensor]): Box scores for each scale level
  319. with shape (N, num_points * num_classes, H, W).
  320. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  321. level with shape (N, num_points * 4, H, W).
  322. score_factors (list[Tensor]): score_factors for each s
  323. cale level with shape (N, num_points * 1, H, W).
  324. Default: None.
  325. img_metas (list[dict]): Meta information of each image, e.g.,
  326. image size, scaling factor, etc. Default: None.
  327. with_nms (bool): Whether apply nms to the bboxes. Default: True.
  328. Returns:
  329. tuple[Tensor, Tensor] | list[tuple]: When `with_nms` is True,
  330. it is tuple[Tensor, Tensor], first tensor bboxes with shape
  331. [N, num_det, 5], 5 arrange as (x1, y1, x2, y2, score)
  332. and second element is class labels of shape [N, num_det].
  333. When `with_nms` is False, first tensor is bboxes with
  334. shape [N, num_det, 4], second tensor is raw score has
  335. shape [N, num_det, num_classes].
  336. """
  337. assert len(cls_scores) == len(bbox_preds)
  338. num_levels = len(cls_scores)
  339. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  340. mlvl_priors = self.prior_generator.grid_priors(
  341. featmap_sizes,
  342. dtype=bbox_preds[0].dtype,
  343. device=bbox_preds[0].device)
  344. mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
  345. mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
  346. assert len(
  347. img_metas
  348. ) == 1, 'Only support one input image while in exporting to ONNX'
  349. img_shape = img_metas[0]['img_shape_for_onnx']
  350. cfg = self.test_cfg
  351. assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
  352. device = cls_scores[0].device
  353. batch_size = cls_scores[0].shape[0]
  354. # convert to tensor to keep tracing
  355. nms_pre_tensor = torch.tensor(
  356. cfg.get('nms_pre', -1), device=device, dtype=torch.long)
  357. # e.g. Retina, FreeAnchor, etc.
  358. if score_factors is None:
  359. with_score_factors = False
  360. mlvl_score_factor = [None for _ in range(num_levels)]
  361. else:
  362. # e.g. FCOS, PAA, ATSS, etc.
  363. with_score_factors = True
  364. mlvl_score_factor = [
  365. score_factors[i].detach() for i in range(num_levels)
  366. ]
  367. mlvl_score_factors = []
  368. mlvl_batch_bboxes = []
  369. mlvl_scores = []
  370. for cls_score, bbox_pred, score_factors, priors in zip(
  371. mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor,
  372. mlvl_priors):
  373. assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
  374. scores = cls_score.permute(0, 2, 3,
  375. 1).reshape(batch_size, -1,
  376. self.cls_out_channels)
  377. if self.use_sigmoid_cls:
  378. scores = scores.sigmoid()
  379. nms_pre_score = scores
  380. else:
  381. scores = scores.softmax(-1)
  382. nms_pre_score = scores
  383. if with_score_factors:
  384. score_factors = score_factors.permute(0, 2, 3, 1).reshape(
  385. batch_size, -1).sigmoid()
  386. bbox_pred = bbox_pred.permute(0, 2, 3,
  387. 1).reshape(batch_size, -1, 4)
  388. priors = priors.expand(batch_size, -1, priors.size(-1))
  389. # Get top-k predictions
  390. from mmdet.core.export import get_k_for_topk
  391. nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
  392. if nms_pre > 0:
  393. if with_score_factors:
  394. nms_pre_score = (nms_pre_score * score_factors[..., None])
  395. else:
  396. nms_pre_score = nms_pre_score
  397. # Get maximum scores for foreground classes.
  398. if self.use_sigmoid_cls:
  399. max_scores, _ = nms_pre_score.max(-1)
  400. else:
  401. # remind that we set FG labels to [0, num_class-1]
  402. # since mmdet v2.0
  403. # BG cat_id: num_class
  404. max_scores, _ = nms_pre_score[..., :-1].max(-1)
  405. _, topk_inds = max_scores.topk(nms_pre)
  406. batch_inds = torch.arange(
  407. batch_size, device=bbox_pred.device).view(
  408. -1, 1).expand_as(topk_inds).long()
  409. # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
  410. transformed_inds = bbox_pred.shape[1] * batch_inds + topk_inds
  411. priors = priors.reshape(
  412. -1, priors.size(-1))[transformed_inds, :].reshape(
  413. batch_size, -1, priors.size(-1))
  414. bbox_pred = bbox_pred.reshape(-1,
  415. 4)[transformed_inds, :].reshape(
  416. batch_size, -1, 4)
  417. scores = scores.reshape(
  418. -1, self.cls_out_channels)[transformed_inds, :].reshape(
  419. batch_size, -1, self.cls_out_channels)
  420. if with_score_factors:
  421. score_factors = score_factors.reshape(
  422. -1, 1)[transformed_inds].reshape(batch_size, -1)
  423. bboxes = self.bbox_coder.decode(
  424. priors, bbox_pred, max_shape=img_shape)
  425. mlvl_batch_bboxes.append(bboxes)
  426. mlvl_scores.append(scores)
  427. if with_score_factors:
  428. mlvl_score_factors.append(score_factors)
  429. batch_bboxes = torch.cat(mlvl_batch_bboxes, dim=1)
  430. batch_scores = torch.cat(mlvl_scores, dim=1)
  431. if with_score_factors:
  432. batch_score_factors = torch.cat(mlvl_score_factors, dim=1)
  433. # Replace multiclass_nms with ONNX::NonMaxSuppression in deployment
  434. from mmdet.core.export import add_dummy_nms_for_onnx
  435. if not self.use_sigmoid_cls:
  436. batch_scores = batch_scores[..., :self.num_classes]
  437. if with_score_factors:
  438. batch_scores = batch_scores * (batch_score_factors.unsqueeze(2))
  439. if with_nms:
  440. max_output_boxes_per_class = cfg.nms.get(
  441. 'max_output_boxes_per_class', 200)
  442. iou_threshold = cfg.nms.get('iou_threshold', 0.5)
  443. score_threshold = cfg.score_thr
  444. nms_pre = cfg.get('deploy_nms_pre', -1)
  445. return add_dummy_nms_for_onnx(batch_bboxes, batch_scores,
  446. max_output_boxes_per_class,
  447. iou_threshold, score_threshold,
  448. nms_pre, cfg.max_per_img)
  449. else:
  450. return batch_bboxes, batch_scores

No Description

Contributors (3)