
fcos_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Scale
from mmcv.runner import force_fp32

from mmdet.core import multi_apply, reduce_mean
from ..builder import HEADS, build_loss
from .anchor_free_head import AnchorFreeHead

INF = 1e8


@HEADS.register_module()
class FCOSHead(AnchorFreeHead):
    """Anchor-free head used in `FCOS <https://arxiv.org/abs/1904.01355>`_.

    The FCOS head does not use anchor boxes. Instead, bounding boxes are
    predicted at each pixel and a centerness measure is used to suppress
    low-quality predictions.
    Here norm_on_bbox, centerness_on_reg, dcn_on_last_conv are training
    tricks used in the official repo, which will bring remarkable mAP gains
    of up to 4.9. Please see https://github.com/tianzhi0549/FCOS for
    more detail.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        strides (list[int] | list[tuple[int, int]]): Strides of points
            in multiple feature levels. Default: (4, 8, 16, 32, 64).
        regress_ranges (tuple[tuple[int, int]]): Regress range of multiple
            level points.
        center_sampling (bool): If true, use center sampling.
            Default: False.
        center_sample_radius (float): Radius of center sampling.
            Default: 1.5.
        norm_on_bbox (bool): If true, normalize the regression targets with
            FPN strides. Default: False.
        centerness_on_reg (bool): If true, position centerness on the
            regress branch. Please refer to https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
            Default: False.
        conv_bias (bool | str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Default: "auto".
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        loss_centerness (dict): Config of centerness loss.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: norm_cfg=dict(type='GN', num_groups=32,
            requires_grad=True).
        init_cfg (dict or list[dict], optional): Initialization config dict.

    Example:
        >>> self = FCOSHead(11, 7)
        >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
        >>> cls_score, bbox_pred, centerness = self.forward(feats)
        >>> assert len(cls_score) == len(self.scales)
    """  # noqa: E501

    def __init__(self,
                 num_classes,
                 in_channels,
                 regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512),
                                 (512, INF)),
                 center_sampling=False,
                 center_sample_radius=1.5,
                 norm_on_bbox=False,
                 centerness_on_reg=False,
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 loss_centerness=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        self.regress_ranges = regress_ranges
        self.center_sampling = center_sampling
        self.center_sample_radius = center_sample_radius
        self.norm_on_bbox = norm_on_bbox
        self.centerness_on_reg = centerness_on_reg
        super().__init__(
            num_classes,
            in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg,
            **kwargs)
        self.loss_centerness = build_loss(loss_centerness)

    def _init_layers(self):
        """Initialize layers of the head."""
        super()._init_layers()
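        # a 3x3 conv predicts the single-channel centerness map and one
        # learnable Scale per FPN level rescales the bbox regression output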
        self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple:
                cls_scores (list[Tensor]): Box scores for each scale level, \
                    each is a 4D-tensor, the channel number is \
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each \
                    scale level, each is a 4D-tensor, the channel number is \
                    num_points * 4.
                centernesses (list[Tensor]): centerness for each scale level, \
                    each is a 4D-tensor, the channel number is num_points * 1.
        """
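        # run forward_single on every FPN level, pairing each level's feature
        # map with its own learnable Scale and stride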
        return multi_apply(self.forward_single, feats, self.scales,
                           self.strides)

    def forward_single(self, x, scale, stride):
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps, only
                used to normalize the bbox prediction when self.norm_on_bbox
                is True.

        Returns:
            tuple: scores for each class, bbox predictions and centerness \
                predictions of input feature maps.
        """
        cls_score, bbox_pred, cls_feat, reg_feat = super().forward_single(x)
        if self.centerness_on_reg:
            centerness = self.conv_centerness(reg_feat)
        else:
            centerness = self.conv_centerness(cls_feat)
        # scale the bbox_pred of different level
        # float to avoid overflow when enabling FP16
        bbox_pred = scale(bbox_pred).float()
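        # with norm_on_bbox the regression targets are divided by the FPN
        # stride in get_targets, so during training the prediction stays in
        # that normalized scale and is only multiplied back by the stride at
        # test time; otherwise exp() keeps the predicted distances positive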
        if self.norm_on_bbox:
            bbox_pred = F.relu(bbox_pred)
            if not self.training:
                bbox_pred *= stride
        else:
            bbox_pred = bbox_pred.exp()
        return cls_score, bbox_pred, centerness

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            centernesses (list[Tensor]): centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)
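        # all_level_points: one (num_points, 2) tensor of (x, y) point centers
        # per FPN level, generated from the feature map sizes and strides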
        labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
                                                gt_labels)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
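        # average the number of positive points across GPUs so the loss
        # normalization stays consistent in distributed training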
        num_pos = torch.tensor(
            len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
        num_pos = max(reduce_mean(num_pos), 1.0)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels, avg_factor=num_pos)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        # centerness weighted iou loss
        centerness_denorm = max(
            reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)
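        # each positive point's IoU loss is weighted by its centerness target
        # and the weights are normalized by their (GPU-averaged) sum, so
        # off-center points contribute less to the bbox loss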
        if len(pos_inds) > 0:
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_preds)
            pos_decoded_target_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_targets)
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=centerness_denorm)
            loss_centerness = self.loss_centerness(
                pos_centerness, pos_centerness_targets, avg_factor=num_pos)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness)

    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        """Compute regression, classification and centerness targets for points
        in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).
            gt_labels_list (list[Tensor]): Ground truth labels of each box,
                each has shape (num_gt,).

        Returns:
            tuple:
                concat_lvl_labels (list[Tensor]): Labels of each level. \
                concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \
                    level.
        """
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)
        # expand regress ranges to align with points
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
                points[i]) for i in range(num_levels)
        ]
        # concat all levels points and regress ranges
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        concat_points = torch.cat(points, dim=0)
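        # targets are assigned per image over the concatenation of all level
        # points at once, then split back into per-level chunks below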

        # the number of points per img, per lvl
        num_points = [center.size(0) for center in points]

        # get labels and bbox_targets of each image
        labels_list, bbox_targets_list = multi_apply(
            self._get_target_single,
            gt_bboxes_list,
            gt_labels_list,
            points=concat_points,
            regress_ranges=concat_regress_ranges,
            num_points_per_lvl=num_points)

        # split to per img, per level
        labels_list = [labels.split(num_points, 0) for labels in labels_list]
        bbox_targets_list = [
            bbox_targets.split(num_points, 0)
            for bbox_targets in bbox_targets_list
        ]

        # concat per level image
        concat_lvl_labels = []
        concat_lvl_bbox_targets = []
        for i in range(num_levels):
            concat_lvl_labels.append(
                torch.cat([labels[i] for labels in labels_list]))
            bbox_targets = torch.cat(
                [bbox_targets[i] for bbox_targets in bbox_targets_list])
            if self.norm_on_bbox:
                bbox_targets = bbox_targets / self.strides[i]
            concat_lvl_bbox_targets.append(bbox_targets)
        return concat_lvl_labels, concat_lvl_bbox_targets

    def _get_target_single(self, gt_bboxes, gt_labels, points, regress_ranges,
                           num_points_per_lvl):
        """Compute regression and classification targets for a single image."""
        num_points = points.size(0)
        num_gts = gt_labels.size(0)
        if num_gts == 0:
            return gt_labels.new_full((num_points,), self.num_classes), \
                   gt_bboxes.new_zeros((num_points, 4))

        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
            gt_bboxes[:, 3] - gt_bboxes[:, 1])
        # TODO: figure out why these two are different
        # areas = areas[None].expand(num_points, num_gts)
        areas = areas[None].repeat(num_points, 1)
        regress_ranges = regress_ranges[:, None, :].expand(
            num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)
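        # (left, top, right, bottom) distances from every point to the borders
        # of every gt box; a point lies inside a box iff all four are positive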
        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)

        if self.center_sampling:
            # condition1: inside a `center bbox`
            radius = self.center_sample_radius
            center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
            center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
            center_gts = torch.zeros_like(gt_bboxes)
            stride = center_xs.new_zeros(center_xs.shape)
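            # the sampling region is a square of half-size radius * stride
            # around each gt center, with the stride taken from the FPN level
            # the point belongs to, then clipped to the gt box below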
            # project the points on current lvl back to the `original` sizes
            lvl_begin = 0
            for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
                lvl_end = lvl_begin + num_points_lvl
                stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
                lvl_begin = lvl_end

            x_mins = center_xs - stride
            y_mins = center_ys - stride
            x_maxs = center_xs + stride
            y_maxs = center_ys + stride
            center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
                                             x_mins, gt_bboxes[..., 0])
            center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
                                             y_mins, gt_bboxes[..., 1])
            center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
                                             gt_bboxes[..., 2], x_maxs)
            center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
                                             gt_bboxes[..., 3], y_maxs)

            cb_dist_left = xs - center_gts[..., 0]
            cb_dist_right = center_gts[..., 2] - xs
            cb_dist_top = ys - center_gts[..., 1]
            cb_dist_bottom = center_gts[..., 3] - ys
            center_bbox = torch.stack(
                (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
            inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
        else:
            # condition1: inside a gt bbox
            inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0

        # condition2: limit the regression range for each location
        max_regress_distance = bbox_targets.max(-1)[0]
        inside_regress_range = (
            (max_regress_distance >= regress_ranges[..., 0])
            & (max_regress_distance <= regress_ranges[..., 1]))

        # if there are still more than one objects for a location,
        # we choose the one with minimal area
        areas[inside_gt_bbox_mask == 0] = INF
        areas[inside_regress_range == 0] = INF
        min_area, min_area_inds = areas.min(dim=1)

        labels = gt_labels[min_area_inds]
        labels[min_area == INF] = self.num_classes  # set as BG
        bbox_targets = bbox_targets[range(num_points), min_area_inds]

        return labels, bbox_targets

    def centerness_target(self, pos_bbox_targets):
        """Compute centerness targets.

        Args:
            pos_bbox_targets (Tensor): BBox targets of positive bboxes in shape
                (num_pos, 4)

        Returns:
            Tensor: Centerness target.
        """
        # only calculate pos centerness targets, otherwise there may be nan
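        # centerness = sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)));
        # it is 1 at the center of a box and decays towards 0 at its borders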
        left_right = pos_bbox_targets[:, [0, 2]]
        top_bottom = pos_bbox_targets[:, [1, 3]]
        if len(left_right) == 0:
            centerness_targets = left_right[..., 0]
        else:
            centerness_targets = (
                left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
                    top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness_targets)

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Get points according to feature map size.

        This function will be deprecated soon.
        """
        warnings.warn(
            '`_get_points_single` in `FCOSHead` will be '
            'deprecated soon, we support a multi level point generator now. '
            'You can get points of a single level feature map '
            'with `self.prior_generator.single_level_grid_priors` ')

        y, x = super()._get_points_single(featmap_size, stride, dtype, device)
        points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
                             dim=-1) + stride // 2
        return points
