# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Scale
from mmcv.runner import force_fp32

from mmdet.core import (anchor_inside_flags, build_assigner, build_sampler,
                        images_to_levels, multi_apply, reduce_mean, unmap)

from ..builder import HEADS, build_loss
from .anchor_head import AnchorHead


@HEADS.register_module()
class ATSSHead(AnchorHead):
    """Bridging the Gap Between Anchor-based and Anchor-free Detection via
    Adaptive Training Sample Selection.

    The ATSS head structure is similar to FCOS, but ATSS uses anchor boxes
    and assigns labels by Adaptive Training Sample Selection instead of
    max IoU.

    https://arxiv.org/abs/1912.02424
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                 reg_decoded_bbox=True,
                 loss_centerness=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='atss_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        super(ATSSHead, self).__init__(
            num_classes,
            in_channels,
            reg_decoded_bbox=reg_decoded_bbox,
            init_cfg=init_cfg,
            **kwargs)

        self.sampling = False
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # ATSS does not sample (sampling=False), so a PseudoSampler that
            # keeps all assigned anchors is used
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
        self.loss_centerness = build_loss(loss_centerness)

    def _init_layers(self):
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        self.atss_cls = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        self.atss_reg = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 4, 3, padding=1)
        self.atss_centerness = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 1, 3, padding=1)
        self.scales = nn.ModuleList(
            [Scale(1.0) for _ in self.prior_generator.strides])

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores, bbox predictions
                and centernesses.
                cls_scores (list[Tensor]): Classification scores for all scale
                    levels, each is a 4D-tensor, the channels number is
                    num_anchors * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for all scale
                    levels, each is a 4D-tensor, the channels number is
                    num_anchors * 4.
                centernesses (list[Tensor]): Centerness for all scale levels,
                    each is a 4D-tensor, the channels number is
                    num_anchors * 1.
        """
        return multi_apply(self.forward_single, feats, self.scales)

    def forward_single(self, x, scale):
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level, the
                    channels number is num_anchors * num_classes.
                bbox_pred (Tensor): Box energies / deltas for a single scale
                    level, the channels number is num_anchors * 4.
                centerness (Tensor): Centerness for a single scale level, the
                    channels number is num_anchors * 1.
        """
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.atss_cls(cls_feat)
        # following ATSS, no exp is applied to bbox_pred (unlike FCOS)
        bbox_pred = scale(self.atss_reg(reg_feat)).float()
        centerness = self.atss_centerness(reg_feat)
        return cls_score, bbox_pred, centerness

    def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
                    label_weights, bbox_targets, num_total_samples):
        """Compute loss of a single scale level.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            cls_score (Tensor): Box scores for each scale level with shape
                (N, num_anchors * num_classes, H, W).
            bbox_pred (Tensor): Box energies / deltas for each scale level
                with shape (N, num_anchors * 4, H, W).
            centerness (Tensor): Centerness for each scale level with shape
                (N, num_anchors * 1, H, W).
            labels (Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            num_total_samples (int): Number of positive samples that is
                reduced over all GPUs.

        Returns:
            tuple[Tensor]: Classification loss, per-image classification
                loss, bbox loss, centerness loss and the summed centerness
                targets (used as the bbox loss normalizer).
        """
        anchors = anchors.reshape(-1, 4)
        b, c, h, w = cls_score.shape
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(
            -1, self.cls_out_channels).contiguous()
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # classification loss
        # Note: this assumes a modified ``loss_cls`` that, besides the
        # reduced scalar loss, also returns the unreduced per-element loss
        # (upstream mmdet losses return only the scalar)
        loss_cls, loss_batch = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples)
        # reduce the unreduced loss to one scalar per image
        loss_batch = loss_batch.reshape(b, h, w, c).sum(dim=(1, 2, 3))

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_centerness = centerness[pos_inds]

            centerness_targets = self.centerness_target(
                pos_anchors, pos_bbox_targets)
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchors, pos_bbox_pred)

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_bbox_targets,
                weight=centerness_targets,
                avg_factor=1.0)

            # centerness loss
            loss_centerness = self.loss_centerness(
                pos_centerness,
                centerness_targets,
                avg_factor=num_total_samples)
        else:
            loss_bbox = bbox_pred.sum() * 0
            loss_centerness = centerness.sum() * 0
            centerness_targets = bbox_targets.new_tensor(0.)

        return (loss_cls, loss_batch, loss_bbox, loss_centerness,
                centerness_targets.sum())
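
    # Note: unlike upstream mmdet, ``loss_single`` above also returns a
    # per-image classification loss (``loss_batch``). A hedged sketch of how
    # a caller might consume it (hypothetical code, not part of this module):
    #
    #   losses = model.bbox_head.loss(cls_scores, bbox_preds, centernesses,
    #                                 gt_bboxes, gt_labels, img_metas)
    #   per_image = sum(losses['loss_batch'])   # sum over FPN levels -> (N,)
    #   hard_ids = per_image.topk(2).indices    # e.g. pick the hardest images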

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level with
                shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            centernesses (list[Tensor]): Centerness for each scale level
                with shape (N, num_anchors * 1, H, W).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (list[Tensor] | None): Bounding boxes that can
                be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components, including the
                per-image classification losses under ``loss_batch``.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None

        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets

        num_total_samples = reduce_mean(
            torch.tensor(num_total_pos, dtype=torch.float,
                         device=device)).item()
        num_total_samples = max(num_total_samples, 1.0)

        losses_cls, loss_batch, losses_bbox, loss_centerness,\
            bbox_avg_factor = multi_apply(
                self.loss_single,
                anchor_list,
                cls_scores,
                bbox_preds,
                centernesses,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                num_total_samples=num_total_samples)

        bbox_avg_factor = sum(bbox_avg_factor)
        bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item()
        losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_centerness=loss_centerness,
            loss_batch=loss_batch)
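
    # A worked example of the distributed normalizers above: with 2 GPUs
    # whose per-rank ``bbox_avg_factor`` sums are 6.0 and 10.0,
    # ``reduce_mean`` yields 8.0 on every rank, so all ranks divide their
    # bbox losses by the same factor and gradients stay consistent after
    # all-reduce. ``num_total_samples`` is averaged across ranks in the same
    # way before it is used as ``avg_factor``.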

    def centerness_target(self, anchors, gts):
        """Compute the FCOS-style centerness targets of positive anchors.

        Only the centerness of positive anchors is calculated; otherwise
        there may be NaNs.

        Args:
            anchors (Tensor): Positive anchors with shape (num_pos, 4).
            gts (Tensor): Corresponding GT boxes with shape (num_pos, 4).

        Returns:
            Tensor: Centerness targets with shape (num_pos,).
        """
        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
        l_ = anchors_cx - gts[:, 0]
        t_ = anchors_cy - gts[:, 1]
        r_ = gts[:, 2] - anchors_cx
        b_ = gts[:, 3] - anchors_cy

        left_right = torch.stack([l_, r_], dim=1)
        top_bottom = torch.stack([t_, b_], dim=1)
        centerness = torch.sqrt(
            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
            (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
        assert not torch.isnan(centerness).any()
        return centerness
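
    # The centerness target follows FCOS (https://arxiv.org/abs/1904.01355):
    #
    #   centerness = sqrt( min(l, r) / max(l, r) * min(t, b) / max(t, b) )
    #
    # Worked example: an anchor centre at (5, 5) inside the GT box
    # (0, 0, 10, 10) gives l = r = t = b = 5 and a target of 1.0; shifting
    # the centre to (8, 5) gives l = 8, r = 2, so sqrt(2/8 * 1) = 0.5.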

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True):
        """Get targets for the ATSS head.

        This method is almost the same as ``AnchorHead.get_targets()``.
        Besides returning the targets as the parent method does, it also
        returns the anchors as the first element of the returned tuple.
        """
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        num_level_anchors_list = [num_level_anchors] * num_imgs

        # concat all level anchors and flags to a single tensor
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            anchor_list[i] = torch.cat(anchor_list[i])
            valid_flag_list[i] = torch.cat(valid_flag_list[i])

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
         all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
             self._get_target_single,
             anchor_list,
             valid_flag_list,
             num_level_anchors_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             img_metas,
             label_channels=label_channels,
             unmap_outputs=unmap_outputs)
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # split targets to a list w.r.t. multiple levels
        anchors_list = images_to_levels(all_anchors, num_level_anchors)
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        return (anchors_list, labels_list, label_weights_list,
                bbox_targets_list, bbox_weights_list, num_total_pos,
                num_total_neg)

    def _get_target_single(self,
                           flat_anchors,
                           valid_flags,
                           num_level_anchors,
                           gt_bboxes,
                           gt_bboxes_ignore,
                           gt_labels,
                           img_meta,
                           label_channels=1,
                           unmap_outputs=True):
        """Compute regression, classification targets for anchors in a single
        image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors, 4).
            valid_flags (Tensor): Multi-level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors,).
            num_level_anchors (list[int]): Number of anchors of each scale
                level.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            img_meta (dict): Meta info of the image.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: N is the number of total anchors in the image.
                anchors (Tensor): All anchors in the image with shape (N, 4).
                labels (Tensor): Labels of all anchors in the image with
                    shape (N,).
                label_weights (Tensor): Label weights of all anchors in the
                    image with shape (N,).
                bbox_targets (Tensor): BBox targets of all anchors in the
                    image with shape (N, 4).
                bbox_weights (Tensor): BBox weights of all anchors in the
                    image with shape (N, 4).
                pos_inds (Tensor): Indices of positive anchors with shape
                    (num_pos,).
                neg_inds (Tensor): Indices of negative anchors with shape
                    (num_neg,).
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg.allowed_border)
        if not inside_flags.any():
            return (None, ) * 7
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]

        num_level_anchors_inside = self.get_num_level_anchors_inside(
            num_level_anchors, inside_flags)
        assign_result = self.assigner.assign(anchors,
                                             num_level_anchors_inside,
                                             gt_bboxes, gt_bboxes_ignore,
                                             gt_labels)

        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors,
                                          dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) + len(neg_inds) < num_valid_anchors:
            # with the PseudoSampler every valid anchor should be either
            # positive or negative; anything else means the assigner ignored
            # some anchors (gt_inds == -1)
            print('warning: some valid anchors are neither positive nor '
                  'negative')
        if len(pos_inds) > 0:
            if self.reg_decoded_bbox:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            else:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)

            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class since v2.5.0
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            anchors = unmap(anchors, num_total_anchors, inside_flags)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors,
                                 inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors,
                                 inside_flags)

        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
                pos_inds, neg_inds)

    def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
        """Count the number of inside-image anchors of each scale level."""
        split_inside_flags = torch.split(inside_flags, num_level_anchors)
        num_level_anchors_inside = [
            int(flags.sum()) for flags in split_inside_flags
        ]
        return num_level_anchors_inside
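

if __name__ == '__main__':
    # Minimal smoke test, a sketch only: the shapes and the 5-level feature
    # pyramid below are hypothetical, and the head is normally built from a
    # config via the HEADS registry rather than constructed directly.
    head = ATSSHead(num_classes=80, in_channels=256)
    feats = tuple(
        torch.rand(2, 256, size, size) for size in (64, 32, 16, 8, 4))
    cls_scores, bbox_preds, centernesses = head(feats)
    # one prediction map per pyramid level
    assert len(cls_scores) == len(head.prior_generator.strides)
    print([tuple(score.shape) for score in cls_scores])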
