You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dynamic_roi_head.py 6.7 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. from mmdet.core import bbox2roi
  5. from mmdet.models.losses import SmoothL1Loss
  6. from ..builder import HEADS
  7. from .standard_roi_head import StandardRoIHead
  # Floor for the median regression target: when the median of the recorded
  # beta history falls below this value, `update_hyperparameters` keeps the
  # current SmoothL1 beta, which avoids a zero or vanishingly small beta.
  EPS = 1e-15
  @HEADS.register_module()
  class DynamicRoIHead(StandardRoIHead):
      """RoI head for `Dynamic R-CNN <https://arxiv.org/abs/2004.06002>`_.

      While training, the head records one IoU statistic per iteration and
      one regression-target statistic per iteration that has positive
      samples. Every ``train_cfg.dynamic_rcnn.update_iter_interval`` recorded
      iterations it uses those statistics to reset the IoU thresholds of the
      bbox assigner and the ``beta`` of the SmoothL1 bbox loss.
      """

      def __init__(self, **kwargs):
          super(DynamicRoIHead, self).__init__(**kwargs)
          # `beta` is rewritten during training, so the bbox loss must be a
          # SmoothL1Loss.
          assert isinstance(self.bbox_head.loss_bbox, SmoothL1Loss)
          # the IoU history of the past `update_iter_interval` iterations
          self.iou_history = []
          # the beta history of the past `update_iter_interval` iterations
          self.beta_history = []

      def forward_train(self,
                        x,
                        img_metas,
                        proposal_list,
                        gt_bboxes,
                        gt_labels,
                        gt_bboxes_ignore=None,
                        gt_masks=None):
          """Forward function for training.

          Besides computing the losses, this appends the batch-averaged
          ``iou_topk``-th largest IoU to ``self.iou_history`` and calls
          :meth:`update_hyperparameters` whenever the history length is a
          multiple of ``update_iter_interval``.

          Args:
              x (list[Tensor]): list of multi-level img features.
              img_metas (list[dict]): list of image info dict where each dict
                  has: 'img_shape', 'scale_factor', 'flip', and may also
                  contain 'filename', 'ori_shape', 'pad_shape', and
                  'img_norm_cfg'. For details on the values of these keys see
                  `mmdet/datasets/pipelines/formatting.py:Collect`.
              proposal_list (list[Tensors]): list of region proposals.
              gt_bboxes (list[Tensor]): each item are the truth boxes for
                  each image in [tl_x, tl_y, br_x, br_y] format.
              gt_labels (list[Tensor]): class indices corresponding to each
                  box.
              gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                  boxes can be ignored when computing the loss.
              gt_masks (None | Tensor): true segmentation masks for each box
                  used if the architecture supports a segmentation task.

          Returns:
              dict[str, Tensor]: a dictionary of loss components
          """
          dynamic_cfg = self.train_cfg.dynamic_rcnn

          # assign gts and sample proposals
          if self.with_bbox or self.with_mask:
              num_imgs = len(img_metas)
              if gt_bboxes_ignore is None:
                  gt_bboxes_ignore = [None for _ in range(num_imgs)]
              sampling_results = []
              cur_iou = []
              for i in range(num_imgs):
                  assign_result = self.bbox_assigner.assign(
                      proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
                      gt_labels[i])
                  sampling_result = self.bbox_sampler.sample(
                      assign_result,
                      proposal_list[i],
                      gt_bboxes[i],
                      gt_labels[i],
                      feats=[lvl_feat[i][None] for lvl_feat in x])
                  # record the `iou_topk`-th largest IoU in an image
                  iou_topk = min(dynamic_cfg.iou_topk,
                                 len(assign_result.max_overlaps))
                  ious, _ = torch.topk(assign_result.max_overlaps, iou_topk)
                  cur_iou.append(ious[-1].item())
                  sampling_results.append(sampling_result)
              # average the current IoUs over images
              cur_iou = np.mean(cur_iou)
              self.iou_history.append(cur_iou)

          losses = dict()
          # bbox head forward and loss
          if self.with_bbox:
              bbox_results = self._bbox_forward_train(x, sampling_results,
                                                      gt_bboxes, gt_labels,
                                                      img_metas)
              losses.update(bbox_results['loss_bbox'])

          # mask head forward and loss
          if self.with_mask:
              mask_results = self._mask_forward_train(
                  x, sampling_results, bbox_results['bbox_feats'], gt_masks,
                  img_metas)
              losses.update(mask_results['loss_mask'])

          # update IoU threshold and SmoothL1 beta; the history is cleared on
          # every update, so this fires once per `update_iter_interval`
          # recorded iterations
          if len(self.iou_history) % dynamic_cfg.update_iter_interval == 0:
              self.update_hyperparameters()

          return losses

      def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                              img_metas):
          """Run the bbox head, compute its loss and record beta statistics.

          The ``beta_topk * num_imgs``-th smallest mean absolute value of the
          first two regression-target columns of the positive samples is
          appended to ``self.beta_history``. Nothing is recorded when the
          batch contains no positive sample.

          Returns:
              dict: the output of ``self._bbox_forward`` extended with a
              ``loss_bbox`` entry.
          """
          num_imgs = len(img_metas)
          rois = bbox2roi([res.bboxes for res in sampling_results])
          bbox_results = self._bbox_forward(x, rois)

          bbox_targets = self.bbox_head.get_targets(sampling_results,
                                                    gt_bboxes, gt_labels,
                                                    self.train_cfg)
          # record the `beta_topk`-th smallest target
          # `bbox_targets[2]` and `bbox_targets[3]` stand for bbox_targets
          # and bbox_weights, respectively
          pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1)
          num_pos = len(pos_inds)
          # `torch.kthvalue` needs k >= 1, so a batch without positive
          # samples contributes no statistic instead of raising.
          if num_pos > 0:
              cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1)
              beta_topk = min(
                  self.train_cfg.dynamic_rcnn.beta_topk * num_imgs, num_pos)
              cur_target = torch.kthvalue(cur_target, beta_topk)[0].item()
              self.beta_history.append(cur_target)

          loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                          bbox_results['bbox_pred'], rois,
                                          *bbox_targets)
          bbox_results.update(loss_bbox=loss_bbox)
          return bbox_results

      def update_hyperparameters(self):
          """Update hyperparameters like IoU thresholds for assigner and beta
          for SmoothL1 loss based on the training statistics.

          Both histories are cleared afterwards. The current ``beta`` is kept
          when no beta statistic was recorded or when the median of the
          recorded ones is below ``EPS``.

          Returns:
              tuple[float]: the updated ``iou_thr`` and ``beta``.
          """
          dynamic_cfg = self.train_cfg.dynamic_rcnn

          # the IoU threshold never drops below its initial value
          new_iou_thr = max(dynamic_cfg.initial_iou,
                            np.mean(self.iou_history))
          self.iou_history = []
          self.bbox_assigner.pos_iou_thr = new_iou_thr
          self.bbox_assigner.neg_iou_thr = new_iou_thr
          self.bbox_assigner.min_pos_iou = new_iou_thr

          # `np.median([])` is NaN, so an empty history is handled explicitly
          median_beta = None
          if self.beta_history:
              median_beta = np.median(self.beta_history)
          if median_beta is None or median_beta < EPS:
              # avoid 0 or too small value for new_beta
              new_beta = self.bbox_head.loss_bbox.beta
          else:
              # beta never grows above its initial value
              new_beta = min(dynamic_cfg.initial_beta, median_beta)
          self.beta_history = []
          self.bbox_head.loss_bbox.beta = new_beta
          return new_iou_thr, new_beta

No Description

Contributors (2)