You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

anchor_head.py 25 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.runner import force_fp32
  6. from mmdet.core import (anchor_inside_flags, build_assigner, build_bbox_coder,
  7. build_prior_generator, build_sampler, images_to_levels,
  8. multi_apply, unmap)
  9. from ..builder import HEADS, build_loss
  10. from .base_dense_head import BaseDenseHead
  11. from .dense_test_mixins import BBoxTestMixin
  12. @HEADS.register_module()
  13. class AnchorHead(BaseDenseHead, BBoxTestMixin):
  14. """Anchor-based head (RPN, RetinaNet, SSD, etc.).
  15. Args:
  16. num_classes (int): Number of categories excluding the background
  17. category.
  18. in_channels (int): Number of channels in the input feature map.
  19. feat_channels (int): Number of hidden channels. Used in child classes.
  20. anchor_generator (dict): Config dict for anchor generator
  21. bbox_coder (dict): Config of bounding box coder.
  22. reg_decoded_bbox (bool): If true, the regression loss would be
  23. applied directly on decoded bounding boxes, converting both
  24. the predicted boxes and regression targets to absolute
  25. coordinates format. Default False. It should be `True` when
  26. using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
  27. loss_cls (dict): Config of classification loss.
  28. loss_bbox (dict): Config of localization loss.
  29. train_cfg (dict): Training config of anchor head.
  30. test_cfg (dict): Testing config of anchor head.
  31. init_cfg (dict or list[dict], optional): Initialization config dict.
  32. """ # noqa: W605
  33. def __init__(self,
  34. num_classes,
  35. in_channels,
  36. feat_channels=256,
  37. anchor_generator=dict(
  38. type='AnchorGenerator',
  39. scales=[8, 16, 32],
  40. ratios=[0.5, 1.0, 2.0],
  41. strides=[4, 8, 16, 32, 64]),
  42. bbox_coder=dict(
  43. type='DeltaXYWHBBoxCoder',
  44. clip_border=True,
  45. target_means=(.0, .0, .0, .0),
  46. target_stds=(1.0, 1.0, 1.0, 1.0)),
  47. reg_decoded_bbox=False,
  48. loss_cls=dict(
  49. type='CrossEntropyLoss',
  50. use_sigmoid=True,
  51. loss_weight=1.0),
  52. loss_bbox=dict(
  53. type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
  54. train_cfg=None,
  55. test_cfg=None,
  56. init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)):
  57. super(AnchorHead, self).__init__(init_cfg)
  58. self.in_channels = in_channels
  59. self.num_classes = num_classes
  60. self.feat_channels = feat_channels
  61. self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
  62. if self.use_sigmoid_cls:
  63. self.cls_out_channels = num_classes
  64. else:
  65. self.cls_out_channels = num_classes + 1
  66. if self.cls_out_channels <= 0:
  67. raise ValueError(f'num_classes={num_classes} is too small')
  68. self.reg_decoded_bbox = reg_decoded_bbox
  69. self.bbox_coder = build_bbox_coder(bbox_coder)
  70. self.loss_cls = build_loss(loss_cls)
  71. self.loss_bbox = build_loss(loss_bbox)
  72. self.train_cfg = train_cfg
  73. self.test_cfg = test_cfg
  74. if self.train_cfg:
  75. self.assigner = build_assigner(self.train_cfg.assigner)
  76. if hasattr(self.train_cfg,
  77. 'sampler') and self.train_cfg.sampler.type.split(
  78. '.')[-1] != 'PseudoSampler':
  79. self.sampling = True
  80. sampler_cfg = self.train_cfg.sampler
  81. # avoid BC-breaking
  82. if loss_cls['type'] in [
  83. 'FocalLoss', 'GHMC', 'QualityFocalLoss'
  84. ]:
  85. warnings.warn(
  86. 'DeprecationWarning: Determining whether to sampling'
  87. 'by loss type is deprecated, please delete sampler in'
  88. 'your config when using `FocalLoss`, `GHMC`, '
  89. '`QualityFocalLoss` or other FocalLoss variant.')
  90. self.sampling = False
  91. sampler_cfg = dict(type='PseudoSampler')
  92. else:
  93. self.sampling = False
  94. sampler_cfg = dict(type='PseudoSampler')
  95. self.sampler = build_sampler(sampler_cfg, context=self)
  96. self.fp16_enabled = False
  97. self.prior_generator = build_prior_generator(anchor_generator)
  98. # Usually the numbers of anchors for each level are the same
  99. # except SSD detectors. So it is an int in the most dense
  100. # heads but a list of int in SSDHead
  101. self.num_base_priors = self.prior_generator.num_base_priors[0]
  102. self._init_layers()
  103. @property
  104. def num_anchors(self):
  105. warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
  106. 'for consistency or also use '
  107. '`num_base_priors` instead')
  108. return self.prior_generator.num_base_priors[0]
  109. @property
  110. def anchor_generator(self):
  111. warnings.warn('DeprecationWarning: anchor_generator is deprecated, '
  112. 'please use "prior_generator" instead')
  113. return self.prior_generator
  114. def _init_layers(self):
  115. """Initialize layers of the head."""
  116. self.conv_cls = nn.Conv2d(self.in_channels,
  117. self.num_base_priors * self.cls_out_channels,
  118. 1)
  119. self.conv_reg = nn.Conv2d(self.in_channels, self.num_base_priors * 4,
  120. 1)
  121. def forward_single(self, x):
  122. """Forward feature of a single scale level.
  123. Args:
  124. x (Tensor): Features of a single scale level.
  125. Returns:
  126. tuple:
  127. cls_score (Tensor): Cls scores for a single scale level \
  128. the channels number is num_base_priors * num_classes.
  129. bbox_pred (Tensor): Box energies / deltas for a single scale \
  130. level, the channels number is num_base_priors * 4.
  131. """
  132. cls_score = self.conv_cls(x)
  133. bbox_pred = self.conv_reg(x)
  134. return cls_score, bbox_pred
  135. def forward(self, feats):
  136. """Forward features from the upstream network.
  137. Args:
  138. feats (tuple[Tensor]): Features from the upstream network, each is
  139. a 4D-tensor.
  140. Returns:
  141. tuple: A tuple of classification scores and bbox prediction.
  142. - cls_scores (list[Tensor]): Classification scores for all \
  143. scale levels, each is a 4D-tensor, the channels number \
  144. is num_base_priors * num_classes.
  145. - bbox_preds (list[Tensor]): Box energies / deltas for all \
  146. scale levels, each is a 4D-tensor, the channels number \
  147. is num_base_priors * 4.
  148. """
  149. return multi_apply(self.forward_single, feats)
  150. def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
  151. """Get anchors according to feature map sizes.
  152. Args:
  153. featmap_sizes (list[tuple]): Multi-level feature map sizes.
  154. img_metas (list[dict]): Image meta info.
  155. device (torch.device | str): Device for returned tensors
  156. Returns:
  157. tuple:
  158. anchor_list (list[Tensor]): Anchors of each image.
  159. valid_flag_list (list[Tensor]): Valid flags of each image.
  160. """
  161. num_imgs = len(img_metas)
  162. # since feature map sizes of all images are the same, we only compute
  163. # anchors for one time
  164. multi_level_anchors = self.prior_generator.grid_priors(
  165. featmap_sizes, device=device)
  166. anchor_list = [multi_level_anchors for _ in range(num_imgs)]
  167. # for each image, we compute valid flags of multi level anchors
  168. valid_flag_list = []
  169. for img_id, img_meta in enumerate(img_metas):
  170. multi_level_flags = self.prior_generator.valid_flags(
  171. featmap_sizes, img_meta['pad_shape'], device)
  172. valid_flag_list.append(multi_level_flags)
  173. return anchor_list, valid_flag_list
  174. def _get_targets_single(self,
  175. flat_anchors,
  176. valid_flags,
  177. gt_bboxes,
  178. gt_bboxes_ignore,
  179. gt_labels,
  180. img_meta,
  181. label_channels=1,
  182. unmap_outputs=True):
  183. """Compute regression and classification targets for anchors in a
  184. single image.
  185. Args:
  186. flat_anchors (Tensor): Multi-level anchors of the image, which are
  187. concatenated into a single tensor of shape (num_anchors ,4)
  188. valid_flags (Tensor): Multi level valid flags of the image,
  189. which are concatenated into a single tensor of
  190. shape (num_anchors,).
  191. gt_bboxes (Tensor): Ground truth bboxes of the image,
  192. shape (num_gts, 4).
  193. gt_bboxes_ignore (Tensor): Ground truth bboxes to be
  194. ignored, shape (num_ignored_gts, 4).
  195. img_meta (dict): Meta info of the image.
  196. gt_labels (Tensor): Ground truth labels of each box,
  197. shape (num_gts,).
  198. label_channels (int): Channel of label.
  199. unmap_outputs (bool): Whether to map outputs back to the original
  200. set of anchors.
  201. Returns:
  202. tuple:
  203. labels_list (list[Tensor]): Labels of each level
  204. label_weights_list (list[Tensor]): Label weights of each level
  205. bbox_targets_list (list[Tensor]): BBox targets of each level
  206. bbox_weights_list (list[Tensor]): BBox weights of each level
  207. num_total_pos (int): Number of positive samples in all images
  208. num_total_neg (int): Number of negative samples in all images
  209. """
  210. inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
  211. img_meta['img_shape'][:2],
  212. self.train_cfg.allowed_border)
  213. if not inside_flags.any():
  214. return (None, ) * 7
  215. # assign gt and sample anchors
  216. anchors = flat_anchors[inside_flags, :]
  217. assign_result = self.assigner.assign(
  218. anchors, gt_bboxes, gt_bboxes_ignore,
  219. None if self.sampling else gt_labels)
  220. sampling_result = self.sampler.sample(assign_result, anchors,
  221. gt_bboxes)
  222. num_valid_anchors = anchors.shape[0]
  223. bbox_targets = torch.zeros_like(anchors)
  224. bbox_weights = torch.zeros_like(anchors)
  225. labels = anchors.new_full((num_valid_anchors, ),
  226. self.num_classes,
  227. dtype=torch.long)
  228. label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
  229. pos_inds = sampling_result.pos_inds
  230. neg_inds = sampling_result.neg_inds
  231. if len(pos_inds) > 0:
  232. if not self.reg_decoded_bbox:
  233. pos_bbox_targets = self.bbox_coder.encode(
  234. sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
  235. else:
  236. pos_bbox_targets = sampling_result.pos_gt_bboxes
  237. bbox_targets[pos_inds, :] = pos_bbox_targets
  238. bbox_weights[pos_inds, :] = 1.0
  239. if gt_labels is None:
  240. # Only rpn gives gt_labels as None
  241. # Foreground is the first class since v2.5.0
  242. labels[pos_inds] = 0
  243. else:
  244. labels[pos_inds] = gt_labels[
  245. sampling_result.pos_assigned_gt_inds]
  246. if self.train_cfg.pos_weight <= 0:
  247. label_weights[pos_inds] = 1.0
  248. else:
  249. label_weights[pos_inds] = self.train_cfg.pos_weight
  250. if len(neg_inds) > 0:
  251. label_weights[neg_inds] = 1.0
  252. # map up to original set of anchors
  253. if unmap_outputs:
  254. num_total_anchors = flat_anchors.size(0)
  255. labels = unmap(
  256. labels, num_total_anchors, inside_flags,
  257. fill=self.num_classes) # fill bg label
  258. label_weights = unmap(label_weights, num_total_anchors,
  259. inside_flags)
  260. bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
  261. bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
  262. return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
  263. neg_inds, sampling_result)
  264. def get_targets(self,
  265. anchor_list,
  266. valid_flag_list,
  267. gt_bboxes_list,
  268. img_metas,
  269. gt_bboxes_ignore_list=None,
  270. gt_labels_list=None,
  271. label_channels=1,
  272. unmap_outputs=True,
  273. return_sampling_results=False):
  274. """Compute regression and classification targets for anchors in
  275. multiple images.
  276. Args:
  277. anchor_list (list[list[Tensor]]): Multi level anchors of each
  278. image. The outer list indicates images, and the inner list
  279. corresponds to feature levels of the image. Each element of
  280. the inner list is a tensor of shape (num_anchors, 4).
  281. valid_flag_list (list[list[Tensor]]): Multi level valid flags of
  282. each image. The outer list indicates images, and the inner list
  283. corresponds to feature levels of the image. Each element of
  284. the inner list is a tensor of shape (num_anchors, )
  285. gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
  286. img_metas (list[dict]): Meta info of each image.
  287. gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
  288. ignored.
  289. gt_labels_list (list[Tensor]): Ground truth labels of each box.
  290. label_channels (int): Channel of label.
  291. unmap_outputs (bool): Whether to map outputs back to the original
  292. set of anchors.
  293. Returns:
  294. tuple: Usually returns a tuple containing learning targets.
  295. - labels_list (list[Tensor]): Labels of each level.
  296. - label_weights_list (list[Tensor]): Label weights of each
  297. level.
  298. - bbox_targets_list (list[Tensor]): BBox targets of each level.
  299. - bbox_weights_list (list[Tensor]): BBox weights of each level.
  300. - num_total_pos (int): Number of positive samples in all
  301. images.
  302. - num_total_neg (int): Number of negative samples in all
  303. images.
  304. additional_returns: This function enables user-defined returns from
  305. `self._get_targets_single`. These returns are currently refined
  306. to properties at each feature map (i.e. having HxW dimension).
  307. The results will be concatenated after the end
  308. """
  309. num_imgs = len(img_metas)
  310. assert len(anchor_list) == len(valid_flag_list) == num_imgs
  311. # anchor number of multi levels
  312. num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
  313. # concat all level anchors to a single tensor
  314. concat_anchor_list = []
  315. concat_valid_flag_list = []
  316. for i in range(num_imgs):
  317. assert len(anchor_list[i]) == len(valid_flag_list[i])
  318. concat_anchor_list.append(torch.cat(anchor_list[i]))
  319. concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
  320. # compute targets for each image
  321. if gt_bboxes_ignore_list is None:
  322. gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
  323. if gt_labels_list is None:
  324. gt_labels_list = [None for _ in range(num_imgs)]
  325. results = multi_apply(
  326. self._get_targets_single,
  327. concat_anchor_list,
  328. concat_valid_flag_list,
  329. gt_bboxes_list,
  330. gt_bboxes_ignore_list,
  331. gt_labels_list,
  332. img_metas,
  333. label_channels=label_channels,
  334. unmap_outputs=unmap_outputs)
  335. (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
  336. pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
  337. rest_results = list(results[7:]) # user-added return values
  338. # no valid anchors
  339. if any([labels is None for labels in all_labels]):
  340. return None
  341. # sampled anchors of all images
  342. num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
  343. num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
  344. # split targets to a list w.r.t. multiple levels
  345. labels_list = images_to_levels(all_labels, num_level_anchors)
  346. label_weights_list = images_to_levels(all_label_weights,
  347. num_level_anchors)
  348. bbox_targets_list = images_to_levels(all_bbox_targets,
  349. num_level_anchors)
  350. bbox_weights_list = images_to_levels(all_bbox_weights,
  351. num_level_anchors)
  352. res = (labels_list, label_weights_list, bbox_targets_list,
  353. bbox_weights_list, num_total_pos, num_total_neg)
  354. if return_sampling_results:
  355. res = res + (sampling_results_list, )
  356. for i, r in enumerate(rest_results): # user-added return values
  357. rest_results[i] = images_to_levels(r, num_level_anchors)
  358. return res + tuple(rest_results)
  359. def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
  360. bbox_targets, bbox_weights, num_total_samples):
  361. """Compute loss of a single scale level.
  362. Args:
  363. cls_score (Tensor): Box scores for each scale level
  364. Has shape (N, num_anchors * num_classes, H, W).
  365. bbox_pred (Tensor): Box energies / deltas for each scale
  366. level with shape (N, num_anchors * 4, H, W).
  367. anchors (Tensor): Box reference for each scale level with shape
  368. (N, num_total_anchors, 4).
  369. labels (Tensor): Labels of each anchors with shape
  370. (N, num_total_anchors).
  371. label_weights (Tensor): Label weights of each anchor with shape
  372. (N, num_total_anchors)
  373. bbox_targets (Tensor): BBox regression targets of each anchor
  374. weight shape (N, num_total_anchors, 4).
  375. bbox_weights (Tensor): BBox regression loss weights of each anchor
  376. with shape (N, num_total_anchors, 4).
  377. num_total_samples (int): If sampling, num total samples equal to
  378. the number of total anchors; Otherwise, it is the number of
  379. positive anchors.
  380. Returns:
  381. dict[str, Tensor]: A dictionary of loss components.
  382. """
  383. # classification loss
  384. labels = labels.reshape(-1)
  385. label_weights = label_weights.reshape(-1)
  386. cls_score = cls_score.permute(0, 2, 3,
  387. 1).reshape(-1, self.cls_out_channels)
  388. loss_cls = self.loss_cls(
  389. cls_score, labels, label_weights, avg_factor=num_total_samples)
  390. # regression loss
  391. bbox_targets = bbox_targets.reshape(-1, 4)
  392. bbox_weights = bbox_weights.reshape(-1, 4)
  393. bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
  394. if self.reg_decoded_bbox:
  395. # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
  396. # is applied directly on the decoded bounding boxes, it
  397. # decodes the already encoded coordinates to absolute format.
  398. anchors = anchors.reshape(-1, 4)
  399. bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
  400. loss_bbox = self.loss_bbox(
  401. bbox_pred,
  402. bbox_targets,
  403. bbox_weights,
  404. avg_factor=num_total_samples)
  405. return loss_cls, loss_bbox
  406. @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
  407. def loss(self,
  408. cls_scores,
  409. bbox_preds,
  410. gt_bboxes,
  411. gt_labels,
  412. img_metas,
  413. gt_bboxes_ignore=None):
  414. """Compute losses of the head.
  415. Args:
  416. cls_scores (list[Tensor]): Box scores for each scale level
  417. Has shape (N, num_anchors * num_classes, H, W)
  418. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  419. level with shape (N, num_anchors * 4, H, W)
  420. gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
  421. shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
  422. gt_labels (list[Tensor]): class indices corresponding to each box
  423. img_metas (list[dict]): Meta information of each image, e.g.,
  424. image size, scaling factor, etc.
  425. gt_bboxes_ignore (None | list[Tensor]): specify which bounding
  426. boxes can be ignored when computing the loss. Default: None
  427. Returns:
  428. dict[str, Tensor]: A dictionary of loss components.
  429. """
  430. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  431. assert len(featmap_sizes) == self.prior_generator.num_levels
  432. device = cls_scores[0].device
  433. anchor_list, valid_flag_list = self.get_anchors(
  434. featmap_sizes, img_metas, device=device)
  435. label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
  436. cls_reg_targets = self.get_targets(
  437. anchor_list,
  438. valid_flag_list,
  439. gt_bboxes,
  440. img_metas,
  441. gt_bboxes_ignore_list=gt_bboxes_ignore,
  442. gt_labels_list=gt_labels,
  443. label_channels=label_channels)
  444. if cls_reg_targets is None:
  445. return None
  446. (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
  447. num_total_pos, num_total_neg) = cls_reg_targets
  448. num_total_samples = (
  449. num_total_pos + num_total_neg if self.sampling else num_total_pos)
  450. # anchor number of multi levels
  451. num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
  452. # concat all level anchors and flags to a single tensor
  453. concat_anchor_list = []
  454. for i in range(len(anchor_list)):
  455. concat_anchor_list.append(torch.cat(anchor_list[i]))
  456. all_anchor_list = images_to_levels(concat_anchor_list,
  457. num_level_anchors)
  458. losses_cls, losses_bbox = multi_apply(
  459. self.loss_single,
  460. cls_scores,
  461. bbox_preds,
  462. all_anchor_list,
  463. labels_list,
  464. label_weights_list,
  465. bbox_targets_list,
  466. bbox_weights_list,
  467. num_total_samples=num_total_samples)
  468. return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
  469. def aug_test(self, feats, img_metas, rescale=False):
  470. """Test function with test time augmentation.
  471. Args:
  472. feats (list[Tensor]): the outer list indicates test-time
  473. augmentations and inner Tensor should have a shape NxCxHxW,
  474. which contains features for all images in the batch.
  475. img_metas (list[list[dict]]): the outer list indicates test-time
  476. augs (multiscale, flip, etc.) and the inner list indicates
  477. images in a batch. each dict has image information.
  478. rescale (bool, optional): Whether to rescale the results.
  479. Defaults to False.
  480. Returns:
  481. list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
  482. The first item is ``bboxes`` with shape (n, 5), where
  483. 5 represent (tl_x, tl_y, br_x, br_y, score).
  484. The shape of the second tensor in the tuple is ``labels``
  485. with shape (n,), The length of list should always be 1.
  486. """
  487. return self.aug_test_bboxes(feats, img_metas, rescale=rescale)

No Description

Contributors (3)