
gfl_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from mmcv.runner import force_fp32

from mmdet.core import (anchor_inside_flags, bbox_overlaps, build_assigner,
                        build_sampler, images_to_levels, multi_apply,
                        reduce_mean, unmap)
from mmdet.core.utils import filter_scores_and_topk
from ..builder import HEADS, build_loss
from .anchor_head import AnchorHead


class Integral(nn.Module):
    """A fixed layer for calculating integral result from distribution.

    This layer calculates the target location by :math:`sum{P(y_i) * y_i}`,
    where :math:`P(y_i)` denotes the softmax vector representing the discrete
    distribution and :math:`y_i` denotes the discrete set, usually
    ``{0, 1, 2, ..., reg_max}``.

    Args:
        reg_max (int): The maximal value of the discrete set. Default: 16.
            You may want to reset it according to your new dataset or
            related settings.
    """

    def __init__(self, reg_max=16):
        super(Integral, self).__init__()
        self.reg_max = reg_max
        self.register_buffer('project',
                             torch.linspace(0, self.reg_max, self.reg_max + 1))

    def forward(self, x):
        """Forward feature from the regression head to get integral result of
        bounding box location.

        Args:
            x (Tensor): Features of the regression head, shape (N, 4*(n+1)),
                n is self.reg_max.

        Returns:
            x (Tensor): Integral result of box locations, i.e., distance
                offsets from the box center in four directions, shape (N, 4).
        """
        x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
        x = F.linear(x, self.project.type_as(x)).reshape(-1, 4)
        return x
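

# --- Editor's sketch (not part of the original file) -------------------------
# A minimal, self-contained check of what `Integral` computes: the expectation
# of a per-side softmax distribution over the discrete set {0, ..., reg_max}.
# The helper name `_demo_integral` is hypothetical, added only for
# illustration.
def _demo_integral():
    reg_max = 16
    layer = Integral(reg_max)
    # One box: 4 sides, each parameterized by (reg_max + 1) logits.
    logits = torch.randn(1, 4 * (reg_max + 1))
    distances = layer(logits)
    # The expectation of values in [0, reg_max] stays inside that range.
    assert distances.shape == (1, 4)
    assert (distances >= 0).all() and (distances <= reg_max).all()
    return distances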


@HEADS.register_module()
class GFLHead(AnchorHead):
    """Generalized Focal Loss: Learning Qualified and Distributed Bounding
    Boxes for Dense Object Detection.

    The GFL head structure is similar to that of ATSS, but GFL uses
    1) a joint representation for classification and localization quality, and
    2) a flexible General distribution for bounding box locations,
    which are supervised by Quality Focal Loss (QFL) and Distribution Focal
    Loss (DFL), respectively.

    https://arxiv.org/abs/2006.04388

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        stacked_convs (int): Number of conv layers in cls and reg tower.
            Default: 4.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='GN', num_groups=32, requires_grad=True).
        loss_dfl (dict): Config of Distribution Focal Loss (DFL).
            Default: dict(type='DistributionFocalLoss', loss_weight=0.25).
        bbox_coder (dict): Config of bbox coder. Defaults to
            'DistancePointBBoxCoder'.
        reg_max (int): Max value of the integral set :math:`{0, ..., reg_max}`
            in QFL setting. Default: 16.
        init_cfg (dict or list[dict], optional): Initialization config dict.

    Example:
        >>> self = GFLHead(11, 7)
        >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
        >>> cls_quality_score, bbox_pred = self.forward(feats)
        >>> assert len(cls_quality_score) == len(self.scales)
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                 loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25),
                 bbox_coder=dict(type='DistancePointBBoxCoder'),
                 reg_max=16,
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='gfl_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.reg_max = reg_max
        super(GFLHead, self).__init__(
            num_classes,
            in_channels,
            bbox_coder=bbox_coder,
            init_cfg=init_cfg,
            **kwargs)

        self.sampling = False
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # GFL does not sample (sampling=False), so a PseudoSampler is used.
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        self.integral = Integral(self.reg_max)
        self.loss_dfl = build_loss(loss_dfl)

    def _init_layers(self):
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        assert self.num_anchors == 1, 'anchor free version'
        self.gfl_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.gfl_reg = nn.Conv2d(
            self.feat_channels, 4 * (self.reg_max + 1), 3, padding=1)
        self.scales = nn.ModuleList(
            [Scale(1.0) for _ in self.prior_generator.strides])

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox
                prediction.

                cls_scores (list[Tensor]): Classification and quality (IoU)
                    joint scores for all scale levels, each is a 4D-tensor,
                    the channel number is num_classes.
                bbox_preds (list[Tensor]): Box distribution logits for all
                    scale levels, each is a 4D-tensor, the channel number is
                    4*(n+1), n is max value of integral set.
        """
        return multi_apply(self.forward_single, feats, self.scales)

    def forward_single(self, x, scale):
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.

        Returns:
            tuple:
                cls_score (Tensor): Cls and quality joint scores for a single
                    scale level, the channel number is num_classes.
                bbox_pred (Tensor): Box distribution logits for a single scale
                    level, the channel number is 4*(n+1), n is max value of
                    integral set.
        """
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.gfl_cls(cls_feat)
        bbox_pred = scale(self.gfl_reg(reg_feat)).float()
        return cls_score, bbox_pred
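
    # --- Editor's sketch (comments only, not part of the original file) ------
    # Per-level output shapes, following the class docstring example
    # (num_classes=11, in_channels=7, default feat_channels and reg_max=16):
    #   >>> x = torch.rand(1, 7, 16, 16)
    #   >>> cls_score, bbox_pred = self.forward_single(x, self.scales[0])
    #   cls_score:  (1, 11, 16, 16)            joint cls-quality logits
    #   bbox_pred:  (1, 4 * (16 + 1), 16, 16)  distribution logits per side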

    def anchor_center(self, anchors):
        """Get anchor centers from anchors.

        Args:
            anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format.

        Returns:
            Tensor: Anchor centers with shape (N, 2), "xy" format.
        """
        anchors_cx = (anchors[..., 2] + anchors[..., 0]) / 2
        anchors_cy = (anchors[..., 3] + anchors[..., 1]) / 2
        return torch.stack([anchors_cx, anchors_cy], dim=-1)
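
    # Editor's note (not part of the original file): for an "xyxy" anchor
    # [0., 0., 8., 8.], anchor_center returns its center [4., 4.].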

    def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, stride, num_total_samples):
        """Compute loss of a single scale level.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            cls_score (Tensor): Cls and quality joint scores for each scale
                level with shape (N, num_classes, H, W).
            bbox_pred (Tensor): Box distribution logits for each scale
                level with shape (N, 4*(n+1), H, W), n is max value of
                integral set.
            labels (Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            stride (tuple): Stride in this scale level.
            num_total_samples (int): Number of positive samples that is
                reduced over all GPUs.

        Returns:
            tuple[Tensor]: loss_cls, loss_bbox, loss_dfl and the sum of the
                IoU-aware weight targets (used later as the normalizer for
                the bbox and DFL losses).
        """
        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1, 4 * (self.reg_max + 1))
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        score = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]

            weight_targets = cls_score.detach().sigmoid()
            weight_targets = weight_targets.max(dim=1)[0][pos_inds]
            pos_bbox_pred_corners = self.integral(pos_bbox_pred)
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchor_centers, pos_bbox_pred_corners)
            pos_decode_bbox_targets = pos_bbox_targets / stride[0]
            score[pos_inds] = bbox_overlaps(
                pos_decode_bbox_pred.detach(),
                pos_decode_bbox_targets,
                is_aligned=True)
            pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
            target_corners = self.bbox_coder.encode(pos_anchor_centers,
                                                    pos_decode_bbox_targets,
                                                    self.reg_max).reshape(-1)

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_decode_bbox_targets,
                weight=weight_targets,
                avg_factor=1.0)

            # dfl loss
            loss_dfl = self.loss_dfl(
                pred_corners,
                target_corners,
                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
                avg_factor=4.0)
        else:
            loss_bbox = bbox_pred.sum() * 0
            loss_dfl = bbox_pred.sum() * 0
            weight_targets = bbox_pred.new_tensor(0)

        # cls (qfl) loss
        loss_cls = self.loss_cls(
            cls_score, (labels, score),
            weight=label_weights,
            avg_factor=num_total_samples)

        return loss_cls, loss_bbox, loss_dfl, weight_targets.sum()
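
    # --- Editor's sketch (comments only, not part of the original file) ------
    # DFL supervises the two integer bins that bracket the continuous target
    # distance. E.g. a target y = 4.3 (in stride units, within [0, reg_max])
    # is split between bins 4 and 5 with weights 0.7 and 0.3:
    #   loss_dfl(y) = -(0.7 * log p_4 + 0.3 * log p_5)
    # QFL, in turn, uses the decoded-box IoU (`score` above) as the soft
    # label for the joint classification-quality branch.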

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Cls and quality scores for each scale
                level with shape (N, num_classes, H, W).
            bbox_preds (list[Tensor]): Box distribution logits for each scale
                level with shape (N, 4*(n+1), H, W), n is max value of
                integral set.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None

        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets

        num_total_samples = reduce_mean(
            torch.tensor(num_total_pos, dtype=torch.float,
                         device=device)).item()
        num_total_samples = max(num_total_samples, 1.0)

        losses_cls, losses_bbox, losses_dfl, \
            avg_factor = multi_apply(
                self.loss_single,
                anchor_list,
                cls_scores,
                bbox_preds,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                self.prior_generator.strides,
                num_total_samples=num_total_samples)

        avg_factor = sum(avg_factor)
        avg_factor = reduce_mean(avg_factor).clamp_(min=1).item()
        losses_bbox = list(map(lambda x: x / avg_factor, losses_bbox))
        losses_dfl = list(map(lambda x: x / avg_factor, losses_dfl))
        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dfl=losses_dfl)
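
    # Editor's note (not part of the original file): the bbox and DFL losses
    # are normalized by `avg_factor`, the sum of the IoU-aware weights
    # averaged across GPUs via `reduce_mean`, so the normalizer tracks the
    # predicted quality of the positives rather than their raw count.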

    def _get_bboxes_single(self,
                           cls_score_list,
                           bbox_pred_list,
                           score_factor_list,
                           mlvl_priors,
                           img_meta,
                           cfg,
                           rescale=False,
                           with_nms=True,
                           **kwargs):
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factors from all scale
                levels of a single image. GFL head does not need this value.
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid, has shape
                (num_priors, 4).
            img_meta (dict): Image meta info.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            tuple[Tensor]: Results of detected bboxes and labels. If with_nms
                is False and mlvl_score_factor is None, return mlvl_bboxes and
                mlvl_scores, else return mlvl_bboxes, mlvl_scores and
                mlvl_score_factor. with_nms=False is usually used for aug
                test. If with_nms is True, then return the following format:

                - det_bboxes (Tensor): Predicted bboxes with shape
                  [num_bboxes, 5], where the first 4 columns are bounding
                  box positions (tl_x, tl_y, br_x, br_y) and the 5-th
                  column are scores between 0 and 1.
                - det_labels (Tensor): Predicted labels of the corresponding
                  box with shape [num_bboxes].
        """
        cfg = self.test_cfg if cfg is None else cfg
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_labels = []
        for level_idx, (cls_score, bbox_pred, stride, priors) in enumerate(
                zip(cls_score_list, bbox_pred_list,
                    self.prior_generator.strides, mlvl_priors)):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            assert stride[0] == stride[1]

            bbox_pred = bbox_pred.permute(1, 2, 0)
            bbox_pred = self.integral(bbox_pred) * stride[0]

            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            results = filter_scores_and_topk(
                scores, cfg.score_thr, nms_pre,
                dict(bbox_pred=bbox_pred, priors=priors))
            scores, labels, _, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            priors = filtered_results['priors']

            bboxes = self.bbox_coder.decode(
                self.anchor_center(priors), bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

        return self._bbox_post_process(
            mlvl_scores,
            mlvl_labels,
            mlvl_bboxes,
            img_meta['scale_factor'],
            cfg,
            rescale=rescale,
            with_nms=with_nms)
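
    # --- Editor's sketch (comments only, not part of the original file) ------
    # Decoding for one prior at stride s: the integral turns 4*(reg_max+1)
    # logits into distances (l, t, r, b) in bin units, which are scaled by s
    # into pixels; `DistancePointBBoxCoder` then builds the box around the
    # anchor center (cx, cy):
    #   x1 = cx - l,  y1 = cy - t,  x2 = cx + r,  y2 = cy + b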

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True):
        """Get targets for GFL head.

        This method is almost the same as `AnchorHead.get_targets()`. Besides
        returning the targets as the parent method does, it also returns the
        anchors as the first element of the returned tuple.
        """
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        num_level_anchors_list = [num_level_anchors] * num_imgs

        # concat all level anchors and flags to a single tensor
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            anchor_list[i] = torch.cat(anchor_list[i])
            valid_flag_list[i] = torch.cat(valid_flag_list[i])

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
         all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
             self._get_target_single,
             anchor_list,
             valid_flag_list,
             num_level_anchors_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             img_metas,
             label_channels=label_channels,
             unmap_outputs=unmap_outputs)
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # split targets to a list w.r.t. multiple levels
        anchors_list = images_to_levels(all_anchors, num_level_anchors)
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        return (anchors_list, labels_list, label_weights_list,
                bbox_targets_list, bbox_weights_list, num_total_pos,
                num_total_neg)
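
    # Editor's note (not part of the original file): `images_to_levels`
    # re-stacks the per-image tensors of length sum(num_level_anchors) into a
    # per-level list, each entry shaped (num_imgs, num_level_anchors[i], ...),
    # which is the layout `loss_single` consumes through `multi_apply`.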

    def _get_target_single(self,
                           flat_anchors,
                           valid_flags,
                           num_level_anchors,
                           gt_bboxes,
                           gt_bboxes_ignore,
                           gt_labels,
                           img_meta,
                           label_channels=1,
                           unmap_outputs=True):
        """Compute regression, classification targets for anchors in a single
        image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors, 4).
            valid_flags (Tensor): Multi level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors,).
            num_level_anchors (list[int]): Number of anchors of each scale
                level.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            img_meta (dict): Meta info of the image.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: N is the number of total anchors in the image.

                anchors (Tensor): All anchors in the image with shape (N, 4).
                labels (Tensor): Labels of all anchors in the image with
                    shape (N,).
                label_weights (Tensor): Label weights of all anchors in the
                    image with shape (N,).
                bbox_targets (Tensor): BBox targets of all anchors in the
                    image with shape (N, 4).
                bbox_weights (Tensor): BBox weights of all anchors in the
                    image with shape (N, 4).
                pos_inds (Tensor): Indices of positive anchors with shape
                    (num_pos,).
                neg_inds (Tensor): Indices of negative anchors with shape
                    (num_neg,).
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg.allowed_border)
        if not inside_flags.any():
            return (None, ) * 7
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]

        num_level_anchors_inside = self.get_num_level_anchors_inside(
            num_level_anchors, inside_flags)
        assign_result = self.assigner.assign(anchors, num_level_anchors_inside,
                                             gt_bboxes, gt_bboxes_ignore,
                                             gt_labels)

        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                # Only RPN gives gt_labels as None;
                # foreground is the first class.
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            anchors = unmap(anchors, num_total_anchors, inside_flags)
            labels = unmap(
                labels, num_total_anchors, inside_flags, fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
                pos_inds, neg_inds)

    def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
        """Count the number of inside-image anchors of each scale level."""
        split_inside_flags = torch.split(inside_flags, num_level_anchors)
        num_level_anchors_inside = [
            int(flags.sum()) for flags in split_inside_flags
        ]
        return num_level_anchors_inside
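

# --- Editor's sketch (not part of the original file) -------------------------
# A minimal smoke test mirroring the class docstring example. The helper name
# `_smoke_test` and the anchor generator settings (taken from the standard GFL
# config; one anchor per location, so the `num_anchors == 1` assertion in
# `_init_layers` holds) are assumptions added for illustration.
def _smoke_test():
    head = GFLHead(
        num_classes=11,
        in_channels=7,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            octave_base_scale=8,
            scales_per_octave=1,
            strides=[8, 16, 32, 64, 128]))
    feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
    cls_quality_scores, bbox_preds = head.forward(feats)
    assert len(cls_quality_scores) == len(head.scales)
    # Distribution logits carry 4 * (reg_max + 1) channels per location.
    assert bbox_preds[0].shape[1] == 4 * (head.reg_max + 1)


if __name__ == '__main__':
    _demo_integral()
    _smoke_test()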
