You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fsaf_head.py 19 kB

2 years ago
Lines 1–433
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. from mmcv.runner import force_fp32
  5. from mmdet.core import (anchor_inside_flags, images_to_levels, multi_apply,
  6. unmap)
  7. from ..builder import HEADS
  8. from ..losses.accuracy import accuracy
  9. from ..losses.utils import weight_reduce_loss
  10. from .retina_head import RetinaHead
  11. @HEADS.register_module()
  12. class FSAFHead(RetinaHead):
  13. """Anchor-free head used in `FSAF <https://arxiv.org/abs/1903.00621>`_.
  14. The head contains two subnetworks. The first classifies anchor boxes and
  15. the second regresses deltas for the anchors (num_anchors is 1 for anchor-
  16. free methods)
  17. Args:
  18. *args: Same as its base class in :class:`RetinaHead`
  19. score_threshold (float, optional): The score_threshold to calculate
  20. positive recall. If given, prediction scores lower than this value
  21. is counted as incorrect prediction. Default to None.
  22. init_cfg (dict or list[dict], optional): Initialization config dict.
  23. Default: None
  24. **kwargs: Same as its base class in :class:`RetinaHead`
  25. Example:
  26. >>> import torch
  27. >>> self = FSAFHead(11, 7)
  28. >>> x = torch.rand(1, 7, 32, 32)
  29. >>> cls_score, bbox_pred = self.forward_single(x)
  30. >>> # Each anchor predicts a score for each class except background
  31. >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
  32. >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
  33. >>> assert cls_per_anchor == self.num_classes
  34. >>> assert box_per_anchor == 4
  35. """
  36. def __init__(self, *args, score_threshold=None, init_cfg=None, **kwargs):
  37. # The positive bias in self.retina_reg conv is to prevent predicted \
  38. # bbox with 0 area
  39. if init_cfg is None:
  40. init_cfg = dict(
  41. type='Normal',
  42. layer='Conv2d',
  43. std=0.01,
  44. override=[
  45. dict(
  46. type='Normal',
  47. name='retina_cls',
  48. std=0.01,
  49. bias_prob=0.01),
  50. dict(
  51. type='Normal', name='retina_reg', std=0.01, bias=0.25)
  52. ])
  53. super().__init__(*args, init_cfg=init_cfg, **kwargs)
  54. self.score_threshold = score_threshold
  55. def forward_single(self, x):
  56. """Forward feature map of a single scale level.
  57. Args:
  58. x (Tensor): Feature map of a single scale level.
  59. Returns:
  60. tuple (Tensor):
  61. cls_score (Tensor): Box scores for each scale level
  62. Has shape (N, num_points * num_classes, H, W).
  63. bbox_pred (Tensor): Box energies / deltas for each scale
  64. level with shape (N, num_points * 4, H, W).
  65. """
  66. cls_score, bbox_pred = super().forward_single(x)
  67. # relu: TBLR encoder only accepts positive bbox_pred
  68. return cls_score, self.relu(bbox_pred)
  69. def _get_targets_single(self,
  70. flat_anchors,
  71. valid_flags,
  72. gt_bboxes,
  73. gt_bboxes_ignore,
  74. gt_labels,
  75. img_meta,
  76. label_channels=1,
  77. unmap_outputs=True):
  78. """Compute regression and classification targets for anchors in a
  79. single image.
  80. Most of the codes are the same with the base class
  81. :obj: `AnchorHead`, except that it also collects and returns
  82. the matched gt index in the image (from 0 to num_gt-1). If the
  83. anchor bbox is not matched to any gt, the corresponding value in
  84. pos_gt_inds is -1.
  85. """
  86. inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
  87. img_meta['img_shape'][:2],
  88. self.train_cfg.allowed_border)
  89. if not inside_flags.any():
  90. return (None, ) * 7
  91. # Assign gt and sample anchors
  92. anchors = flat_anchors[inside_flags.type(torch.bool), :]
  93. assign_result = self.assigner.assign(
  94. anchors, gt_bboxes, gt_bboxes_ignore,
  95. None if self.sampling else gt_labels)
  96. sampling_result = self.sampler.sample(assign_result, anchors,
  97. gt_bboxes)
  98. num_valid_anchors = anchors.shape[0]
  99. bbox_targets = torch.zeros_like(anchors)
  100. bbox_weights = torch.zeros_like(anchors)
  101. labels = anchors.new_full((num_valid_anchors, ),
  102. self.num_classes,
  103. dtype=torch.long)
  104. label_weights = anchors.new_zeros((num_valid_anchors, label_channels),
  105. dtype=torch.float)
  106. pos_gt_inds = anchors.new_full((num_valid_anchors, ),
  107. -1,
  108. dtype=torch.long)
  109. pos_inds = sampling_result.pos_inds
  110. neg_inds = sampling_result.neg_inds
  111. if len(pos_inds) > 0:
  112. if not self.reg_decoded_bbox:
  113. pos_bbox_targets = self.bbox_coder.encode(
  114. sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
  115. else:
  116. # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
  117. # is applied directly on the decoded bounding boxes, both
  118. # the predicted boxes and regression targets should be with
  119. # absolute coordinate format.
  120. pos_bbox_targets = sampling_result.pos_gt_bboxes
  121. bbox_targets[pos_inds, :] = pos_bbox_targets
  122. bbox_weights[pos_inds, :] = 1.0
  123. # The assigned gt_index for each anchor. (0-based)
  124. pos_gt_inds[pos_inds] = sampling_result.pos_assigned_gt_inds
  125. if gt_labels is None:
  126. # Only rpn gives gt_labels as None
  127. # Foreground is the first class
  128. labels[pos_inds] = 0
  129. else:
  130. labels[pos_inds] = gt_labels[
  131. sampling_result.pos_assigned_gt_inds]
  132. if self.train_cfg.pos_weight <= 0:
  133. label_weights[pos_inds] = 1.0
  134. else:
  135. label_weights[pos_inds] = self.train_cfg.pos_weight
  136. if len(neg_inds) > 0:
  137. label_weights[neg_inds] = 1.0
  138. # shadowed_labels is a tensor composed of tuples
  139. # (anchor_inds, class_label) that indicate those anchors lying in the
  140. # outer region of a gt or overlapped by another gt with a smaller
  141. # area.
  142. #
  143. # Therefore, only the shadowed labels are ignored for loss calculation.
  144. # the key `shadowed_labels` is defined in :obj:`CenterRegionAssigner`
  145. shadowed_labels = assign_result.get_extra_property('shadowed_labels')
  146. if shadowed_labels is not None and shadowed_labels.numel():
  147. if len(shadowed_labels.shape) == 2:
  148. idx_, label_ = shadowed_labels[:, 0], shadowed_labels[:, 1]
  149. assert (labels[idx_] != label_).all(), \
  150. 'One label cannot be both positive and ignored'
  151. label_weights[idx_, label_] = 0
  152. else:
  153. label_weights[shadowed_labels] = 0
  154. # map up to original set of anchors
  155. if unmap_outputs:
  156. num_total_anchors = flat_anchors.size(0)
  157. labels = unmap(labels, num_total_anchors, inside_flags)
  158. label_weights = unmap(label_weights, num_total_anchors,
  159. inside_flags)
  160. bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
  161. bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
  162. pos_gt_inds = unmap(
  163. pos_gt_inds, num_total_anchors, inside_flags, fill=-1)
  164. return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
  165. neg_inds, sampling_result, pos_gt_inds)
  166. @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
  167. def loss(self,
  168. cls_scores,
  169. bbox_preds,
  170. gt_bboxes,
  171. gt_labels,
  172. img_metas,
  173. gt_bboxes_ignore=None):
  174. """Compute loss of the head.
  175. Args:
  176. cls_scores (list[Tensor]): Box scores for each scale level
  177. Has shape (N, num_points * num_classes, H, W).
  178. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  179. level with shape (N, num_points * 4, H, W).
  180. gt_bboxes (list[Tensor]): each item are the truth boxes for each
  181. image in [tl_x, tl_y, br_x, br_y] format.
  182. gt_labels (list[Tensor]): class indices corresponding to each box
  183. img_metas (list[dict]): Meta information of each image, e.g.,
  184. image size, scaling factor, etc.
  185. gt_bboxes_ignore (None | list[Tensor]): specify which bounding
  186. boxes can be ignored when computing the loss.
  187. Returns:
  188. dict[str, Tensor]: A dictionary of loss components.
  189. """
  190. for i in range(len(bbox_preds)): # loop over fpn level
  191. # avoid 0 area of the predicted bbox
  192. bbox_preds[i] = bbox_preds[i].clamp(min=1e-4)
  193. # TODO: It may directly use the base-class loss function.
  194. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  195. assert len(featmap_sizes) == self.prior_generator.num_levels
  196. batch_size = len(gt_bboxes)
  197. device = cls_scores[0].device
  198. anchor_list, valid_flag_list = self.get_anchors(
  199. featmap_sizes, img_metas, device=device)
  200. label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
  201. cls_reg_targets = self.get_targets(
  202. anchor_list,
  203. valid_flag_list,
  204. gt_bboxes,
  205. img_metas,
  206. gt_bboxes_ignore_list=gt_bboxes_ignore,
  207. gt_labels_list=gt_labels,
  208. label_channels=label_channels)
  209. if cls_reg_targets is None:
  210. return None
  211. (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
  212. num_total_pos, num_total_neg,
  213. pos_assigned_gt_inds_list) = cls_reg_targets
  214. num_gts = np.array(list(map(len, gt_labels)))
  215. num_total_samples = (
  216. num_total_pos + num_total_neg if self.sampling else num_total_pos)
  217. # anchor number of multi levels
  218. num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
  219. # concat all level anchors and flags to a single tensor
  220. concat_anchor_list = []
  221. for i in range(len(anchor_list)):
  222. concat_anchor_list.append(torch.cat(anchor_list[i]))
  223. all_anchor_list = images_to_levels(concat_anchor_list,
  224. num_level_anchors)
  225. losses_cls, losses_bbox = multi_apply(
  226. self.loss_single,
  227. cls_scores,
  228. bbox_preds,
  229. all_anchor_list,
  230. labels_list,
  231. label_weights_list,
  232. bbox_targets_list,
  233. bbox_weights_list,
  234. num_total_samples=num_total_samples)
  235. # `pos_assigned_gt_inds_list` (length: fpn_levels) stores the assigned
  236. # gt index of each anchor bbox in each fpn level.
  237. cum_num_gts = list(np.cumsum(num_gts)) # length of batch_size
  238. for i, assign in enumerate(pos_assigned_gt_inds_list):
  239. # loop over fpn levels
  240. for j in range(1, batch_size):
  241. # loop over batch size
  242. # Convert gt indices in each img to those in the batch
  243. assign[j][assign[j] >= 0] += int(cum_num_gts[j - 1])
  244. pos_assigned_gt_inds_list[i] = assign.flatten()
  245. labels_list[i] = labels_list[i].flatten()
  246. num_gts = sum(map(len, gt_labels)) # total number of gt in the batch
  247. # The unique label index of each gt in the batch
  248. label_sequence = torch.arange(num_gts, device=device)
  249. # Collect the average loss of each gt in each level
  250. with torch.no_grad():
  251. loss_levels, = multi_apply(
  252. self.collect_loss_level_single,
  253. losses_cls,
  254. losses_bbox,
  255. pos_assigned_gt_inds_list,
  256. labels_seq=label_sequence)
  257. # Shape: (fpn_levels, num_gts). Loss of each gt at each fpn level
  258. loss_levels = torch.stack(loss_levels, dim=0)
  259. # Locate the best fpn level for loss back-propagation
  260. if loss_levels.numel() == 0: # zero gt
  261. argmin = loss_levels.new_empty((num_gts, ), dtype=torch.long)
  262. else:
  263. _, argmin = loss_levels.min(dim=0)
  264. # Reweight the loss of each (anchor, label) pair, so that only those
  265. # at the best gt level are back-propagated.
  266. losses_cls, losses_bbox, pos_inds = multi_apply(
  267. self.reweight_loss_single,
  268. losses_cls,
  269. losses_bbox,
  270. pos_assigned_gt_inds_list,
  271. labels_list,
  272. list(range(len(losses_cls))),
  273. min_levels=argmin)
  274. num_pos = torch.cat(pos_inds, 0).sum().float()
  275. pos_recall = self.calculate_pos_recall(cls_scores, labels_list,
  276. pos_inds)
  277. if num_pos == 0: # No gt
  278. avg_factor = num_pos + float(num_total_neg)
  279. else:
  280. avg_factor = num_pos
  281. for i in range(len(losses_cls)):
  282. losses_cls[i] /= avg_factor
  283. losses_bbox[i] /= avg_factor
  284. return dict(
  285. loss_cls=losses_cls,
  286. loss_bbox=losses_bbox,
  287. num_pos=num_pos / batch_size,
  288. pos_recall=pos_recall)
  289. def calculate_pos_recall(self, cls_scores, labels_list, pos_inds):
  290. """Calculate positive recall with score threshold.
  291. Args:
  292. cls_scores (list[Tensor]): Classification scores at all fpn levels.
  293. Each tensor is in shape (N, num_classes * num_anchors, H, W)
  294. labels_list (list[Tensor]): The label that each anchor is assigned
  295. to. Shape (N * H * W * num_anchors, )
  296. pos_inds (list[Tensor]): List of bool tensors indicating whether
  297. the anchor is assigned to a positive label.
  298. Shape (N * H * W * num_anchors, )
  299. Returns:
  300. Tensor: A single float number indicating the positive recall.
  301. """
  302. with torch.no_grad():
  303. num_class = self.num_classes
  304. scores = [
  305. cls.permute(0, 2, 3, 1).reshape(-1, num_class)[pos]
  306. for cls, pos in zip(cls_scores, pos_inds)
  307. ]
  308. labels = [
  309. label.reshape(-1)[pos]
  310. for label, pos in zip(labels_list, pos_inds)
  311. ]
  312. scores = torch.cat(scores, dim=0)
  313. labels = torch.cat(labels, dim=0)
  314. if self.use_sigmoid_cls:
  315. scores = scores.sigmoid()
  316. else:
  317. scores = scores.softmax(dim=1)
  318. return accuracy(scores, labels, thresh=self.score_threshold)
  319. def collect_loss_level_single(self, cls_loss, reg_loss, assigned_gt_inds,
  320. labels_seq):
  321. """Get the average loss in each FPN level w.r.t. each gt label.
  322. Args:
  323. cls_loss (Tensor): Classification loss of each feature map pixel,
  324. shape (num_anchor, num_class)
  325. reg_loss (Tensor): Regression loss of each feature map pixel,
  326. shape (num_anchor, 4)
  327. assigned_gt_inds (Tensor): It indicates which gt the prior is
  328. assigned to (0-based, -1: no assignment). shape (num_anchor),
  329. labels_seq: The rank of labels. shape (num_gt)
  330. Returns:
  331. shape: (num_gt), average loss of each gt in this level
  332. """
  333. if len(reg_loss.shape) == 2: # iou loss has shape (num_prior, 4)
  334. reg_loss = reg_loss.sum(dim=-1) # sum loss in tblr dims
  335. if len(cls_loss.shape) == 2:
  336. cls_loss = cls_loss.sum(dim=-1) # sum loss in class dims
  337. loss = cls_loss + reg_loss
  338. assert loss.size(0) == assigned_gt_inds.size(0)
  339. # Default loss value is 1e6 for a layer where no anchor is positive
  340. # to ensure it will not be chosen to back-propagate gradient
  341. losses_ = loss.new_full(labels_seq.shape, 1e6)
  342. for i, l in enumerate(labels_seq):
  343. match = assigned_gt_inds == l
  344. if match.any():
  345. losses_[i] = loss[match].mean()
  346. return losses_,
  347. def reweight_loss_single(self, cls_loss, reg_loss, assigned_gt_inds,
  348. labels, level, min_levels):
  349. """Reweight loss values at each level.
  350. Reassign loss values at each level by masking those where the
  351. pre-calculated loss is too large. Then return the reduced losses.
  352. Args:
  353. cls_loss (Tensor): Element-wise classification loss.
  354. Shape: (num_anchors, num_classes)
  355. reg_loss (Tensor): Element-wise regression loss.
  356. Shape: (num_anchors, 4)
  357. assigned_gt_inds (Tensor): The gt indices that each anchor bbox
  358. is assigned to. -1 denotes a negative anchor, otherwise it is the
  359. gt index (0-based). Shape: (num_anchors, ),
  360. labels (Tensor): Label assigned to anchors. Shape: (num_anchors, ).
  361. level (int): The current level index in the pyramid
  362. (0-4 for RetinaNet)
  363. min_levels (Tensor): The best-matching level for each gt.
  364. Shape: (num_gts, ),
  365. Returns:
  366. tuple:
  367. - cls_loss: Reduced corrected classification loss. Scalar.
  368. - reg_loss: Reduced corrected regression loss. Scalar.
  369. - pos_flags (Tensor): Corrected bool tensor indicating the
  370. final positive anchors. Shape: (num_anchors, ).
  371. """
  372. loc_weight = torch.ones_like(reg_loss)
  373. cls_weight = torch.ones_like(cls_loss)
  374. pos_flags = assigned_gt_inds >= 0 # positive pixel flag
  375. pos_indices = torch.nonzero(pos_flags, as_tuple=False).flatten()
  376. if pos_flags.any(): # pos pixels exist
  377. pos_assigned_gt_inds = assigned_gt_inds[pos_flags]
  378. zeroing_indices = (min_levels[pos_assigned_gt_inds] != level)
  379. neg_indices = pos_indices[zeroing_indices]
  380. if neg_indices.numel():
  381. pos_flags[neg_indices] = 0
  382. loc_weight[neg_indices] = 0
  383. # Only the weight corresponding to the label is
  384. # zeroed out if not selected
  385. zeroing_labels = labels[neg_indices]
  386. assert (zeroing_labels >= 0).all()
  387. cls_weight[neg_indices, zeroing_labels] = 0
  388. # Weighted loss for both cls and reg loss
  389. cls_loss = weight_reduce_loss(cls_loss, cls_weight, reduction='sum')
  390. reg_loss = weight_reduce_loss(reg_loss, loc_weight, reduction='sum')
  391. return cls_loss, reg_loss, pos_flags

No Description

Contributors (3)