
autoassign_head.py 23 kB

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import bias_init_with_prob, normal_init
from mmcv.runner import force_fp32

from mmdet.core import multi_apply
from mmdet.core.anchor.point_generator import MlvlPointGenerator
from mmdet.core.bbox import bbox_overlaps
from mmdet.models import HEADS
from mmdet.models.dense_heads.atss_head import reduce_mean
from mmdet.models.dense_heads.fcos_head import FCOSHead
from mmdet.models.dense_heads.paa_head import levels_to_images

EPS = 1e-12


class CenterPrior(nn.Module):
    """Center Weighting module to adjust the category-specific prior
    distributions.

    Args:
        force_topk (bool): When no point falls into a gt_bbox, forcibly
            select the k points closest to its center to calculate
            the center prior. Defaults to False.
        topk (int): The number of points used to calculate the
            center prior when no point falls in a gt_bbox. Only works
            when force_topk is True. Defaults to 9.
        num_classes (int): The number of classes in the dataset.
            Defaults to 80.
        strides (tuple[int]): The stride of each input feature map.
            Defaults to (8, 16, 32, 64, 128).
    """

    def __init__(self,
                 force_topk=False,
                 topk=9,
                 num_classes=80,
                 strides=(8, 16, 32, 64, 128)):
        super(CenterPrior, self).__init__()
        self.mean = nn.Parameter(torch.zeros(num_classes, 2))
        self.sigma = nn.Parameter(torch.ones(num_classes, 2))
        self.strides = strides
        self.force_topk = force_topk
        self.topk = topk

    def forward(self, anchor_points_list, gt_bboxes, labels,
                inside_gt_bbox_mask):
        """Get the center prior of each point on the feature map for each
        instance.

        Args:
            anchor_points_list (list[Tensor]): List of point coordinates
                on the feature maps. Each has shape (num_points, 2).
            gt_bboxes (Tensor): The gt_bboxes with shape of (num_gt, 4).
            labels (Tensor): The gt_labels with shape of (num_gt).
            inside_gt_bbox_mask (Tensor): Tensor of bool type, with shape
                of (num_points, num_gt), each value is used to mark
                whether this point falls within a certain gt.

        Returns:
            tuple(Tensor):

                - center_prior_weights (Tensor): Float tensor with shape \
                    of (num_points, num_gt). Each value represents \
                    the center weighting coefficient.
                - inside_gt_bbox_mask (Tensor): Tensor of bool type, \
                    with shape of (num_points, num_gt), each value \
                    is used to mark whether this point falls within \
                    a certain gt or is among the topk nearest points \
                    for a specific gt_bbox.
        """
        inside_gt_bbox_mask = inside_gt_bbox_mask.clone()
        num_gts = len(labels)
        num_points = sum([len(item) for item in anchor_points_list])
        if num_gts == 0:
            return gt_bboxes.new_zeros(num_points,
                                       num_gts), inside_gt_bbox_mask
        center_prior_list = []
        for slvl_points, stride in zip(anchor_points_list, self.strides):
            # slvl_points: points from single level in FPN, has shape (h*w, 2)
            # single_level_points has shape (h*w, num_gt, 2)
            single_level_points = slvl_points[:, None, :].expand(
                (slvl_points.size(0), len(gt_bboxes), 2))
            gt_center_x = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2)
            gt_center_y = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2)
            gt_center = torch.stack((gt_center_x, gt_center_y), dim=1)
            gt_center = gt_center[None]
            # instance_center has shape (1, num_gt, 2)
            instance_center = self.mean[labels][None]
            # instance_sigma has shape (1, num_gt, 2)
            instance_sigma = self.sigma[labels][None]
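            # The center prior is a per-class Gaussian over the
            # stride-normalized offset between each point and the gt
            # center, taken as a product over the x and y axes.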
            # distance has shape (num_points, num_gt, 2)
            distance = (((single_level_points - gt_center) / float(stride) -
                         instance_center)**2)
            center_prior = torch.exp(-distance /
                                     (2 * instance_sigma**2)).prod(dim=-1)
            center_prior_list.append(center_prior)
        center_prior_weights = torch.cat(center_prior_list, dim=0)

        if self.force_topk:
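            # For gts that contain no point at all (e.g. very small
            # boxes), mark the topk points with the largest center prior
            # as inside so they can still act as positive candidates.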
            gt_inds_no_points_inside = torch.nonzero(
                inside_gt_bbox_mask.sum(0) == 0).reshape(-1)
            if gt_inds_no_points_inside.numel():
                topk_center_index = \
                    center_prior_weights[:, gt_inds_no_points_inside].topk(
                        self.topk,
                        dim=0)[1]
                temp_mask = inside_gt_bbox_mask[:, gt_inds_no_points_inside]
                inside_gt_bbox_mask[:, gt_inds_no_points_inside] = \
                    torch.scatter(temp_mask,
                                  dim=0,
                                  index=topk_center_index,
                                  src=torch.ones_like(
                                      topk_center_index,
                                      dtype=torch.bool))

        center_prior_weights[~inside_gt_bbox_mask] = 0
        return center_prior_weights, inside_gt_bbox_mask


@HEADS.register_module()
class AutoAssignHead(FCOSHead):
    """AutoAssignHead head used in AutoAssign.

    More details can be found in the `paper
    <https://arxiv.org/abs/2007.03496>`_ .

    Args:
        force_topk (bool): Used in center prior initialization to
            handle extremely small gt. Defaults to False.
        topk (int): The number of points used to calculate the
            center prior when no point falls in a gt_bbox. Only works
            when force_topk is True. Defaults to 9.
        pos_loss_weight (float): The loss weight of the positive loss.
            Defaults to 0.25.
        neg_loss_weight (float): The loss weight of the negative loss.
            Defaults to 0.75.
        center_loss_weight (float): The loss weight of the center prior
            loss. Defaults to 0.75.
    """

    def __init__(self,
                 *args,
                 force_topk=False,
                 topk=9,
                 pos_loss_weight=0.25,
                 neg_loss_weight=0.75,
                 center_loss_weight=0.75,
                 **kwargs):
        super().__init__(*args, conv_bias=True, **kwargs)
        self.center_prior = CenterPrior(
            force_topk=force_topk,
            topk=topk,
            num_classes=self.num_classes,
            strides=self.strides)
        self.pos_loss_weight = pos_loss_weight
        self.neg_loss_weight = neg_loss_weight
        self.center_loss_weight = center_loss_weight
        self.prior_generator = MlvlPointGenerator(self.strides, offset=0)

    def init_weights(self):
        """Initialize weights of the head.

        In particular, we use special initialization for the biases of the
        classification and regression convs.
        """
        super(AutoAssignHead, self).init_weights()
        bias_cls = bias_init_with_prob(0.02)
        normal_init(self.conv_cls, std=0.01, bias=bias_cls)
        normal_init(self.conv_reg, std=0.01, bias=4.0)

    def forward_single(self, x, scale, stride):
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps, only
                used to normalize the bbox prediction when self.norm_on_bbox
                is True.

        Returns:
            tuple: scores for each class, bbox predictions and centerness \
                predictions of input feature maps.
        """
        cls_score, bbox_pred, cls_feat, reg_feat = super(
            FCOSHead, self).forward_single(x)
        centerness = self.conv_centerness(reg_feat)
        # scale the bbox_pred of different level
        # float to avoid overflow when enabling FP16
        bbox_pred = scale(bbox_pred).float()
        bbox_pred = F.relu(bbox_pred)
        bbox_pred *= stride
        return cls_score, bbox_pred, centerness

    def get_pos_loss_single(self, cls_score, objectness, reg_loss, gt_labels,
                            center_prior_weights):
        """Calculate the positive loss of all points in gt_bboxes.

        Args:
            cls_score (Tensor): All category scores for each point on
                the feature map. The shape is (num_points, num_class).
            objectness (Tensor): Foreground probability of all points,
                has shape (num_points, 1).
            reg_loss (Tensor): The regression loss of each gt_bbox and each
                prediction box, has shape of (num_points, num_gt).
            gt_labels (Tensor): The zero-based gt_labels of all gt
                with shape of (num_gt,).
            center_prior_weights (Tensor): Float tensor with shape
                of (num_points, num_gt). Each value represents
                the center weighting coefficient.

        Returns:
            tuple[Tensor]:

                - pos_loss (Tensor): The positive loss of all points
                  in the gt_bboxes.
        """
        # p_loc: localization confidence
        p_loc = torch.exp(-reg_loss)
        # p_cls: classification confidence
        p_cls = (cls_score * objectness)[:, gt_labels]
        # p_pos: joint confidence indicator
        p_pos = p_cls * p_loc
        # 3 is a hyper-parameter to control the contributions of high and
        # low confidence locations towards positive losses.
        confidence_weight = torch.exp(p_pos * 3)
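        # Normalize the confidence-scaled center prior weights so the
        # per-point weights of each instance sum to one, i.e. a soft,
        # data-driven positive-sample assignment instead of a hard one.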
        p_pos_weight = (confidence_weight * center_prior_weights) / (
            (confidence_weight * center_prior_weights).sum(
                0, keepdim=True)).clamp(min=EPS)
        reweighted_p_pos = (p_pos * p_pos_weight).sum(0)
        pos_loss = F.binary_cross_entropy(
            reweighted_p_pos,
            torch.ones_like(reweighted_p_pos),
            reduction='none')
        pos_loss = pos_loss.sum() * self.pos_loss_weight
        return pos_loss,

    def get_neg_loss_single(self, cls_score, objectness, gt_labels, ious,
                            inside_gt_bbox_mask):
        """Calculate the negative loss of all points in the feature map.

        Args:
            cls_score (Tensor): All category scores for each point on
                the feature map. The shape is (num_points, num_class).
            objectness (Tensor): Foreground probability of all points,
                has shape of (num_points, 1).
            gt_labels (Tensor): The zero-based label of all gt with shape
                of (num_gt).
            ious (Tensor): Float tensor with shape of (num_points, num_gt).
                Each value represents the IoU of the predicted bbox and
                the gt_bboxes.
            inside_gt_bbox_mask (Tensor): Tensor of bool type, with shape
                of (num_points, num_gt), each value is used to mark
                whether this point falls within a certain gt.

        Returns:
            tuple[Tensor]:

                - neg_loss (Tensor): The negative loss of all points
                  in the feature map.
        """
        num_gts = len(gt_labels)
        joint_conf = (cls_score * objectness)
        p_neg_weight = torch.ones_like(joint_conf)
        if num_gts > 0:
            # the order of dimensions would affect the value of
            # p_neg_weight; we strictly follow the original
            # implementation.
            inside_gt_bbox_mask = inside_gt_bbox_mask.permute(1, 0)
            ious = ious.permute(1, 0)

            foreground_idxs = torch.nonzero(inside_gt_bbox_mask, as_tuple=True)
            temp_weight = (1 / (1 - ious[foreground_idxs]).clamp_(EPS))

            def normalize(x):
                return (x - x.min() + EPS) / (x.max() - x.min() + EPS)
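            # A point with a high IoU against its gt gets a large
            # temp_weight and therefore, after per-instance normalization,
            # a small negative weight (1 - temp_weight): well-localized
            # points are penalized less as negatives.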
            for instance_idx in range(num_gts):
                idxs = foreground_idxs[0] == instance_idx
                if idxs.any():
                    temp_weight[idxs] = normalize(temp_weight[idxs])

            p_neg_weight[foreground_idxs[1],
                         gt_labels[foreground_idxs[0]]] = 1 - temp_weight
        logits = (joint_conf * p_neg_weight)
        neg_loss = (
            logits**2 * F.binary_cross_entropy(
                logits, torch.zeros_like(logits), reduction='none'))
        neg_loss = neg_loss.sum() * self.neg_loss_weight
        return neg_loss,

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'objectnesses'))
    def loss(self,
             cls_scores,
             bbox_preds,
             objectnesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            objectnesses (list[Tensor]): objectness for each scale level,
                each is a 4D-tensor, the channel number is num_points * 1.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(objectnesses)
        all_num_gt = sum([len(item) for item in gt_bboxes])
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)
        inside_gt_bbox_mask_list, bbox_targets_list = self.get_targets(
            all_level_points, gt_bboxes)

        center_prior_weight_list = []
        temp_inside_gt_bbox_mask_list = []
        for gt_bbox, gt_label, inside_gt_bbox_mask in zip(
                gt_bboxes, gt_labels, inside_gt_bbox_mask_list):
            center_prior_weight, inside_gt_bbox_mask = \
                self.center_prior(all_level_points, gt_bbox, gt_label,
                                  inside_gt_bbox_mask)
            center_prior_weight_list.append(center_prior_weight)
            temp_inside_gt_bbox_mask_list.append(inside_gt_bbox_mask)
        inside_gt_bbox_mask_list = temp_inside_gt_bbox_mask_list
        mlvl_points = torch.cat(all_level_points, dim=0)
        bbox_preds = levels_to_images(bbox_preds)
        cls_scores = levels_to_images(cls_scores)
        objectnesses = levels_to_images(objectnesses)
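        # levels_to_images converts the per-level (N, C, H, W) predictions
        # into per-image tensors of shape (num_points, C), the layout the
        # per-image loss functions below expect.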
        reg_loss_list = []
        ious_list = []
        num_points = len(mlvl_points)

        for bbox_pred, encoded_targets, inside_gt_bbox_mask in zip(
                bbox_preds, bbox_targets_list, inside_gt_bbox_mask_list):
            temp_num_gt = encoded_targets.size(1)
            expand_mlvl_points = mlvl_points[:, None, :].expand(
                num_points, temp_num_gt, 2).reshape(-1, 2)
            encoded_targets = encoded_targets.reshape(-1, 4)
            expand_bbox_pred = bbox_pred[:, None, :].expand(
                num_points, temp_num_gt, 4).reshape(-1, 4)
            decoded_bbox_preds = self.bbox_coder.decode(
                expand_mlvl_points, expand_bbox_pred)
            decoded_target_preds = self.bbox_coder.decode(
                expand_mlvl_points, encoded_targets)
            with torch.no_grad():
                ious = bbox_overlaps(
                    decoded_bbox_preds, decoded_target_preds, is_aligned=True)
                ious = ious.reshape(num_points, temp_num_gt)
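                # Keep, for each point, only its best IoU over all gts and
                # broadcast it to every gt column, following the original
                # implementation's IoU sharing for the negative weights.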
                if temp_num_gt:
                    ious = ious.max(
                        dim=-1, keepdim=True).values.repeat(1, temp_num_gt)
                else:
                    ious = ious.new_zeros(num_points, temp_num_gt)
                ious[~inside_gt_bbox_mask] = 0
                ious_list.append(ious)
            loss_bbox = self.loss_bbox(
                decoded_bbox_preds,
                decoded_target_preds,
                weight=None,
                reduction_override='none')
            reg_loss_list.append(loss_bbox.reshape(num_points, temp_num_gt))

        cls_scores = [item.sigmoid() for item in cls_scores]
        objectnesses = [item.sigmoid() for item in objectnesses]
        pos_loss_list, = multi_apply(self.get_pos_loss_single, cls_scores,
                                     objectnesses, reg_loss_list, gt_labels,
                                     center_prior_weight_list)
        pos_avg_factor = reduce_mean(
            bbox_pred.new_tensor(all_num_gt)).clamp_(min=1)
        pos_loss = sum(pos_loss_list) / pos_avg_factor

        neg_loss_list, = multi_apply(self.get_neg_loss_single, cls_scores,
                                     objectnesses, gt_labels, ious_list,
                                     inside_gt_bbox_mask_list)
        neg_avg_factor = sum(item.data.sum()
                             for item in center_prior_weight_list)
        neg_avg_factor = reduce_mean(neg_avg_factor).clamp_(min=1)
        neg_loss = sum(neg_loss_list) / neg_avg_factor

        center_loss = []
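        # Minimizing num_gt / sum(center_prior_weights) drives the total
        # center prior weight of each image up, which tunes the learnable
        # per-class mean and sigma of the CenterPrior module.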
        for i in range(len(img_metas)):

            if inside_gt_bbox_mask_list[i].any():
                center_loss.append(
                    len(gt_bboxes[i]) /
                    center_prior_weight_list[i].sum().clamp_(min=EPS))
            # when width or height of gt_bbox is smaller than stride of p3
            else:
                center_loss.append(center_prior_weight_list[i].sum() * 0)

        center_loss = torch.stack(center_loss).mean() * self.center_loss_weight

        # avoid deadlock in DDP when there is no gt in the whole batch
        if all_num_gt == 0:
            pos_loss = bbox_preds[0].sum() * 0
            dummy_center_prior_loss = self.center_prior.mean.sum(
            ) * 0 + self.center_prior.sigma.sum() * 0
            center_loss = objectnesses[0].sum() * 0 + dummy_center_prior_loss

        loss = dict(
            loss_pos=pos_loss, loss_neg=neg_loss, loss_center=center_loss)

        return loss

    def get_targets(self, points, gt_bboxes_list):
        """Compute regression targets and whether each point is inside or
        outside a gt_bbox, for multiple images.

        Args:
            points (list[Tensor]): Points of all fpn levels, each has shape
                (num_points, 2).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).

        Returns:
            tuple(list[Tensor]):

                - inside_gt_bbox_mask_list (list[Tensor]): Each Tensor
                  is of bool type with shape (num_points, num_gt); each
                  value is used to mark whether this point falls within
                  a certain gt.
                - concat_lvl_bbox_targets (list[Tensor]): BBox targets
                  of each level. Each tensor has shape
                  (num_points, num_gt, 4).
        """
        concat_points = torch.cat(points, dim=0)
        # the number of points per img, per lvl
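        # multi_apply maps _get_target_single over the per-image gt boxes;
        # the concatenated points are shared by all images, so they are
        # passed once as a fixed keyword argument.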
        inside_gt_bbox_mask_list, bbox_targets_list = multi_apply(
            self._get_target_single, gt_bboxes_list, points=concat_points)
        return inside_gt_bbox_mask_list, bbox_targets_list

    def _get_target_single(self, gt_bboxes, points):
        """Compute regression targets and whether each point is inside or
        outside a gt_bbox, for a single image.

        Args:
            gt_bboxes (Tensor): gt_bboxes of a single image, has shape
                (num_gt, 4).
            points (Tensor): Points of all fpn levels, has shape
                (num_points, 2).

        Returns:
            tuple[Tensor]: Containing the following Tensors:

                - inside_gt_bbox_mask (Tensor): Bool tensor with shape
                  (num_points, num_gt), each value is used to mark
                  whether this point falls within a certain gt.
                - bbox_targets (Tensor): BBox targets of each point with
                  each gt_bbox, has shape (num_points, num_gt, 4).
        """
        num_points = points.size(0)
        num_gts = gt_bboxes.size(0)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None]
        ys = ys[:, None]
        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)
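        # A point lies strictly inside a gt_bbox iff all four of its
        # (left, top, right, bottom) distances are positive.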
        if num_gts:
            inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
        else:
            inside_gt_bbox_mask = bbox_targets.new_zeros((num_points, num_gts),
                                                         dtype=torch.bool)
        return inside_gt_bbox_mask, bbox_targets

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Almost the same as the implementation in FCOS; we remove the half
        stride offset to align with the original implementation.

        This function will be deprecated soon.
        """
        warnings.warn(
            '`_get_points_single` in `AutoAssignHead` will be '
            'deprecated soon. We support a multi level point generator now; '
            'you can get the points of a single level feature map '
            'with `self.prior_generator.single_level_grid_priors`.')
        y, x = super(FCOSHead,
                     self)._get_points_single(featmap_size, stride, dtype,
                                              device)
        points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
                             dim=-1)
        return points
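
For reference, a minimal construction sketch. The values below (num_classes, channel sizes, GIoU loss) mirror the common AutoAssign reference config and are illustrative assumptions, not settings fixed by this file:

    # Illustrative only: config values follow the reference AutoAssign
    # config and may differ in your setup.
    from mmdet.models import build_head

    bbox_head = build_head(
        dict(
            type='AutoAssignHead',
            num_classes=80,
            in_channels=256,
            stacked_convs=4,
            feat_channels=256,
            strides=[8, 16, 32, 64, 128],
            loss_bbox=dict(type='GIoULoss', loss_weight=5.0)))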
