
detr_head.py 40 kB

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear, build_activation_layer
from mmcv.cnn.bricks.transformer import FFN, build_positional_encoding
from mmcv.runner import force_fp32

from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
                        build_assigner, build_sampler, multi_apply,
                        reduce_mean)
from mmdet.models.utils import build_transformer
from ..builder import HEADS, build_loss
from .anchor_free_head import AnchorFreeHead


@HEADS.register_module()
class DETRHead(AnchorFreeHead):
    """Implements the DETR transformer head.

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        num_classes (int): Number of categories excluding the background.
        in_channels (int): Number of channels in the input feature map.
        num_query (int): Number of queries in the Transformer.
        num_reg_fcs (int, optional): Number of fully-connected layers used in
            `FFN`, which is then used for the regression head. Default 2.
        transformer (obj:`mmcv.ConfigDict`|dict): Config for transformer.
            Default: None.
        sync_cls_avg_factor (bool): Whether to sync the avg_factor of
            all ranks. Default to False.
        positional_encoding (obj:`mmcv.ConfigDict`|dict):
            Config for position encoding.
        loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the
            classification loss. Default `CrossEntropyLoss`.
        loss_bbox (obj:`mmcv.ConfigDict`|dict): Config of the
            regression loss. Default `L1Loss`.
        loss_iou (obj:`mmcv.ConfigDict`|dict): Config of the
            regression iou loss. Default `GIoULoss`.
        train_cfg (obj:`mmcv.ConfigDict`|dict): Training config of
            transformer head.
        test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of
            transformer head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    _version = 2

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_query=100,
                 num_reg_fcs=2,
                 transformer=None,
                 sync_cls_avg_factor=False,
                 positional_encoding=dict(
                     type='SinePositionalEncoding',
                     num_feats=128,
                     normalize=True),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     bg_cls_weight=0.1,
                     use_sigmoid=False,
                     loss_weight=1.0,
                     class_weight=1.0),
                 loss_bbox=dict(type='L1Loss', loss_weight=5.0),
                 loss_iou=dict(type='GIoULoss', loss_weight=2.0),
                 train_cfg=dict(
                     assigner=dict(
                         type='HungarianAssigner',
                         cls_cost=dict(type='ClassificationCost', weight=1.),
                         reg_cost=dict(type='BBoxL1Cost', weight=5.0),
                         iou_cost=dict(
                             type='IoUCost', iou_mode='giou', weight=2.0))),
                 test_cfg=dict(max_per_img=100),
                 init_cfg=None,
                 **kwargs):
        # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
        # since it brings inconvenience when the initialization of
        # `AnchorFreeHead` is called.
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.bg_cls_weight = 0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        class_weight = loss_cls.get('class_weight', None)
        if class_weight is not None and (self.__class__ is DETRHead):
            assert isinstance(class_weight, float), 'Expected ' \
                'class_weight to have type float. Found ' \
                f'{type(class_weight)}.'
            # NOTE following the official DETR repo, bg_cls_weight means
            # relative classification weight of the no-object class.
            bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
            assert isinstance(bg_cls_weight, float), 'Expected ' \
                'bg_cls_weight to have type float. Found ' \
                f'{type(bg_cls_weight)}.'
            class_weight = torch.ones(num_classes + 1) * class_weight
            # set background class as the last index
            class_weight[num_classes] = bg_cls_weight
            loss_cls.update({'class_weight': class_weight})
            if 'bg_cls_weight' in loss_cls:
                loss_cls.pop('bg_cls_weight')
            self.bg_cls_weight = bg_cls_weight

        if train_cfg:
            assert 'assigner' in train_cfg, 'assigner should be provided ' \
                'when train_cfg is set.'
            assigner = train_cfg['assigner']
            assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \
                'The classification weight for loss and matcher should be ' \
                'exactly the same.'
            assert loss_bbox['loss_weight'] == assigner['reg_cost'][
                'weight'], 'The regression L1 weight for loss and matcher ' \
                'should be exactly the same.'
            assert loss_iou['loss_weight'] == assigner['iou_cost']['weight'], \
                'The regression iou weight for loss and matcher should be ' \
                'exactly the same.'
            self.assigner = build_assigner(assigner)
            # DETR sampling=False, so use PseudoSampler
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        self.num_query = num_query
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.num_reg_fcs = num_reg_fcs
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.loss_iou = build_loss(loss_iou)

        if self.loss_cls.use_sigmoid:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
        self.act_cfg = transformer.get('act_cfg',
                                       dict(type='ReLU', inplace=True))
        self.activate = build_activation_layer(self.act_cfg)
        self.positional_encoding = build_positional_encoding(
            positional_encoding)
        self.transformer = build_transformer(transformer)
        self.embed_dims = self.transformer.embed_dims
        assert 'num_feats' in positional_encoding
        num_feats = positional_encoding['num_feats']
        assert num_feats * 2 == self.embed_dims, 'embed_dims should' \
            f' be exactly 2 times of num_feats. Found {self.embed_dims}' \
            f' and {num_feats}.'
        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the transformer head."""
        self.input_proj = Conv2d(
            self.in_channels, self.embed_dims, kernel_size=1)
        self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        self.reg_ffn = FFN(
            self.embed_dims,
            self.embed_dims,
            self.num_reg_fcs,
            self.act_cfg,
            dropout=0.0,
            add_residual=False)
        self.fc_reg = Linear(self.embed_dims, 4)
        self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)

    def init_weights(self):
        """Initialize weights of the transformer head."""
        # The initialization for transformer is important
        self.transformer.init_weights()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """load checkpoints."""
        # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
        # since `AnchorFreeHead._load_from_state_dict` should not be
        # called here. Invoking the default `Module._load_from_state_dict`
        # is enough.

        # Names of some parameters have been changed.
        version = local_metadata.get('version', None)
        if (version is None or version < 2) and self.__class__ is DETRHead:
            convert_dict = {
                '.self_attn.': '.attentions.0.',
                '.ffn.': '.ffns.0.',
                '.multihead_attn.': '.attentions.1.',
                '.decoder.norm.': '.decoder.post_norm.'
            }
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                for ori_key, convert_key in convert_dict.items():
                    if ori_key in k:
                        convert_key = k.replace(ori_key, convert_key)
                        state_dict[convert_key] = state_dict[k]
                        del state_dict[k]

        super(AnchorFreeHead,
              self)._load_from_state_dict(state_dict, prefix, local_metadata,
                                          strict, missing_keys,
                                          unexpected_keys, error_msgs)

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.

                - all_cls_scores_list (list[Tensor]): Classification scores \
                    for each scale level. Each is a 4D-tensor with shape \
                    [nb_dec, bs, num_query, cls_out_channels]. Note \
                    `cls_out_channels` should include background.
                - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
                    outputs for each scale level. Each is a 4D-tensor with \
                    normalized coordinate format (cx, cy, w, h) and shape \
                    [nb_dec, bs, num_query, 4].
        """
        num_levels = len(feats)
        img_metas_list = [img_metas for _ in range(num_levels)]
        return multi_apply(self.forward_single, feats, img_metas_list)

    def forward_single(self, x, img_metas):
        """Forward function for a single feature level.

        Args:
            x (Tensor): Input feature from backbone's single stage, shape
                [bs, c, h, w].
            img_metas (list[dict]): List of image information.

        Returns:
            all_cls_scores (Tensor): Outputs from the classification head,
                shape [nb_dec, bs, num_query, cls_out_channels]. Note
                cls_out_channels should include background.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression
                head with normalized coordinate format (cx, cy, w, h).
                Shape [nb_dec, bs, num_query, 4].
        """
        # construct binary masks which are used for the transformer.
        # NOTE following the official DETR repo, non-zero values represent
        # ignored positions, while zero values mean valid positions.
        batch_size = x.size(0)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        masks = x.new_ones((batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]['img_shape']
            masks[img_id, :img_h, :img_w] = 0

        x = self.input_proj(x)
        # interpolate masks to have the same spatial shape with x
        masks = F.interpolate(
            masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)
        # position encoding
        pos_embed = self.positional_encoding(masks)  # [bs, embed_dim, h, w]
        # outs_dec: [nb_dec, bs, num_query, embed_dim]
        outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
                                       pos_embed)

        all_cls_scores = self.fc_cls(outs_dec)
        all_bbox_preds = self.fc_reg(self.activate(
            self.reg_ffn(outs_dec))).sigmoid()
        return all_cls_scores, all_bbox_preds

    @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
    def loss(self,
             all_cls_scores_list,
             all_bbox_preds_list,
             gt_bboxes_list,
             gt_labels_list,
             img_metas,
             gt_bboxes_ignore=None):
        """Loss function.

        Only outputs from the last feature level are used for computing
        losses by default.

        Args:
            all_cls_scores_list (list[Tensor]): Classification outputs
                for each feature level. Each is a 4D-tensor with shape
                [nb_dec, bs, num_query, cls_out_channels].
            all_bbox_preds_list (list[Tensor]): Sigmoid regression
                outputs for each feature level. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                [nb_dec, bs, num_query, 4].
            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            img_metas (list[dict]): List of image meta information.
            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
                which can be ignored for each image. Default None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # NOTE by default only the outputs from the last feature scale are
        # used.
        all_cls_scores = all_cls_scores_list[-1]
        all_bbox_preds = all_bbox_preds_list[-1]
        assert gt_bboxes_ignore is None, \
            'Only supports for gt_bboxes_ignore setting to None.'

        num_dec_layers = len(all_cls_scores)
        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_bboxes_ignore_list = [
            gt_bboxes_ignore for _ in range(num_dec_layers)
        ]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]

        losses_cls, losses_bbox, losses_iou = multi_apply(
            self.loss_single, all_cls_scores, all_bbox_preds,
            all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
            all_gt_bboxes_ignore_list)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]
        loss_dict['loss_iou'] = losses_iou[-1]
        # loss from other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
                                                       losses_bbox[:-1],
                                                       losses_iou[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
            loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
            num_dec_layer += 1
        return loss_dict

    def loss_single(self,
                    cls_scores,
                    bbox_preds,
                    gt_bboxes_list,
                    gt_labels_list,
                    img_metas,
                    gt_bboxes_ignore_list=None):
        """Loss function for outputs from a single decoder layer of a single
        feature level.

        Args:
            cls_scores (Tensor): Box score logits from a single decoder layer
                for all images. Shape [bs, num_query, cls_out_channels].
            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
                for all images, with normalized coordinate (cx, cy, w, h) and
                shape [bs, num_query, 4].
            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            img_metas (list[dict]): List of image meta information.
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Default None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components for outputs from
                a single decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           gt_bboxes_list, gt_labels_list,
                                           img_metas, gt_bboxes_ignore_list)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        loss_cls = self.loss_cls(
            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used for rescale bboxes
        factors = []
        for img_meta, bbox_pred in zip(img_metas, bbox_preds):
            img_h, img_w, _ = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors, 0)

        # DETR regresses the relative position of boxes (cxcywh) in the image,
        # thus the learning target is normalized by the image size. So here
        # we need to re-scale them for calculating the IoU loss
        bbox_preds = bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)

        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou

    def get_targets(self,
                    cls_scores_list,
                    bbox_preds_list,
                    gt_bboxes_list,
                    gt_labels_list,
                    img_metas,
                    gt_bboxes_ignore_list=None):
        """Compute regression and classification targets for a batch of
        images.

        Outputs from a single decoder layer of a single feature level are used.

        Args:
            cls_scores_list (list[Tensor]): Box score logits from a single
                decoder layer for each image with shape [num_query,
                cls_out_channels].
            bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
                decoder layer for each image, with normalized coordinate
                (cx, cy, w, h) and shape [num_query, 4].
            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            img_metas (list[dict]): List of image meta information.
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Default None.

        Returns:
            tuple: a tuple containing the following targets.

                - labels_list (list[Tensor]): Labels for all images.
                - label_weights_list (list[Tensor]): Label weights for all \
                    images.
                - bbox_targets_list (list[Tensor]): BBox targets for all \
                    images.
                - bbox_weights_list (list[Tensor]): BBox weights for all \
                    images.
                - num_total_pos (int): Number of positive samples in all \
                    images.
                - num_total_neg (int): Number of negative samples in all \
                    images.
        """
        assert gt_bboxes_ignore_list is None, \
            'Only supports for gt_bboxes_ignore setting to None.'
        num_imgs = len(cls_scores_list)
        gt_bboxes_ignore_list = [
            gt_bboxes_ignore_list for _ in range(num_imgs)
        ]

        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
             self._get_target_single, cls_scores_list, bbox_preds_list,
             gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)

    def _get_target_single(self,
                           cls_score,
                           bbox_pred,
                           gt_bboxes,
                           gt_labels,
                           img_meta,
                           gt_bboxes_ignore=None):
        """Compute regression and classification targets for one image.

        Outputs from a single decoder layer of a single feature level are used.

        Args:
            cls_score (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_query, cls_out_channels].
            bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
                for one image, with normalized coordinate (cx, cy, w, h) and
                shape [num_query, 4].
            gt_bboxes (Tensor): Ground truth bboxes for one image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (Tensor): Ground truth class indices for one image
                with shape (num_gts, ).
            img_meta (dict): Meta information for one image.
            gt_bboxes_ignore (Tensor, optional): Bounding boxes
                which can be ignored. Default None.

        Returns:
            tuple[Tensor]: a tuple containing the following for one image.

                - labels (Tensor): Labels of each image.
                - label_weights (Tensor): Label weights of each image.
                - bbox_targets (Tensor): BBox targets of each image.
                - bbox_weights (Tensor): BBox weights of each image.
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
        """
        num_bboxes = bbox_pred.size(0)
        # assigner and sampler
        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes,
                                             gt_labels, img_meta,
                                             gt_bboxes_ignore)
        sampling_result = self.sampler.sample(assign_result, bbox_pred,
                                              gt_bboxes)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label targets
        labels = gt_bboxes.new_full((num_bboxes, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)

        # bbox targets
        bbox_targets = torch.zeros_like(bbox_pred)
        bbox_weights = torch.zeros_like(bbox_pred)
        bbox_weights[pos_inds] = 1.0
        img_h, img_w, _ = img_meta['img_shape']

        # DETR regresses the relative position of boxes (cxcywh) in the image.
        # Thus the learning target should be normalized by the image size, and
        # the box format should be converted from the default x1y1x2y2 to
        # cxcywh.
        factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                       img_h]).unsqueeze(0)
        pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
        bbox_targets[pos_inds] = pos_gt_bboxes_targets
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)

    # over-write because img_metas are needed as inputs for bbox_head.
    def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        """Forward function for training mode.

        Args:
            x (list[Tensor]): Features from backbone.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            proposal_cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert proposal_cfg is None, '"proposal_cfg" must be None'
        outs = self(x, img_metas)
        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, img_metas)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
    def get_bboxes(self,
                   all_cls_scores_list,
                   all_bbox_preds_list,
                   img_metas,
                   rescale=False):
        """Transform network outputs for a batch into bbox predictions.

        Args:
            all_cls_scores_list (list[Tensor]): Classification outputs
                for each feature level. Each is a 4D-tensor with shape
                [nb_dec, bs, num_query, cls_out_channels].
            all_bbox_preds_list (list[Tensor]): Sigmoid regression
                outputs for each feature level. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                [nb_dec, bs, num_query, 4].
            img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If True, return boxes in original
                image space. Default False.

        Returns:
            list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \
                The first item is an (n, 5) tensor, where the first 4 columns \
                are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
                5-th column is a score between 0 and 1. The second item is a \
                (n,) tensor where each item is the predicted class label of \
                the corresponding box.
        """
        # NOTE by default only outputs from the last feature level are used,
        # and only the outputs from the last decoder layer are used.
        cls_scores = all_cls_scores_list[-1][-1]
        bbox_preds = all_bbox_preds_list[-1][-1]

        result_list = []
        for img_id in range(len(img_metas)):
            cls_score = cls_scores[img_id]
            bbox_pred = bbox_preds[img_id]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self._get_bboxes_single(cls_score, bbox_pred,
                                                img_shape, scale_factor,
                                                rescale)
            result_list.append(proposals)
        return result_list

    def _get_bboxes_single(self,
                           cls_score,
                           bbox_pred,
                           img_shape,
                           scale_factor,
                           rescale=False):
        """Transform outputs from the last decoder layer into bbox predictions
        for each image.

        Args:
            cls_score (Tensor): Box score logits from the last decoder layer
                for each image. Shape [num_query, cls_out_channels].
            bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
                for each image, with coordinate format (cx, cy, w, h) and
                shape [num_query, 4].
            img_shape (tuple[int]): Shape of input image, (height, width, 3).
            scale_factor (ndarray, optional): Scale factor of the image
                arranged as (w_scale, h_scale, w_scale, h_scale).
            rescale (bool, optional): If True, return boxes in original image
                space. Default False.

        Returns:
            tuple[Tensor]: Results of detected bboxes and labels.

                - det_bboxes: Predicted bboxes with shape [num_query, 5], \
                    where the first 4 columns are bounding box positions \
                    (tl_x, tl_y, br_x, br_y) and the 5-th column are scores \
                    between 0 and 1.
                - det_labels: Predicted labels of the corresponding box with \
                    shape [num_query].
        """
        assert len(cls_score) == len(bbox_pred)
        max_per_img = self.test_cfg.get('max_per_img', self.num_query)
        # exclude background
        if self.loss_cls.use_sigmoid:
            cls_score = cls_score.sigmoid()
            scores, indexes = cls_score.view(-1).topk(max_per_img)
            det_labels = indexes % self.num_classes
            bbox_index = indexes // self.num_classes
            bbox_pred = bbox_pred[bbox_index]
        else:
            scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)
            scores, bbox_index = scores.topk(max_per_img)
            bbox_pred = bbox_pred[bbox_index]
            det_labels = det_labels[bbox_index]

        det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
        det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
        det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
        det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
        det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
        if rescale:
            det_bboxes /= det_bboxes.new_tensor(scale_factor)
        det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1)

        return det_bboxes, det_labels

    def simple_test_bboxes(self, feats, img_metas, rescale=False):
        """Test det bboxes without test-time augmentation.

        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
                The first item is ``bboxes`` with shape (n, 5),
                where 5 represent (tl_x, tl_y, br_x, br_y, score).
                The shape of the second tensor in the tuple is ``labels``
                with shape (n,).
        """
        # forward of this head requires img_metas
        outs = self.forward(feats, img_metas)
        results_list = self.get_bboxes(*outs, img_metas, rescale=rescale)
        return results_list

    def forward_onnx(self, feats, img_metas):
        """Forward function for exporting to ONNX.

        Over-write `forward` because: `masks` is directly created with
        zero (valid position tag) and has the same spatial size as `x`.
        Thus the construction of `masks` is different from that in `forward`.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.

                - all_cls_scores_list (list[Tensor]): Classification scores \
                    for each scale level. Each is a 4D-tensor with shape \
                    [nb_dec, bs, num_query, cls_out_channels]. Note \
                    `cls_out_channels` should include background.
                - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
                    outputs for each scale level. Each is a 4D-tensor with \
                    normalized coordinate format (cx, cy, w, h) and shape \
                    [nb_dec, bs, num_query, 4].
        """
        num_levels = len(feats)
        img_metas_list = [img_metas for _ in range(num_levels)]
        return multi_apply(self.forward_single_onnx, feats, img_metas_list)

    def forward_single_onnx(self, x, img_metas):
        """Forward function for a single feature level with ONNX exportation.

        Args:
            x (Tensor): Input feature from backbone's single stage, shape
                [bs, c, h, w].
            img_metas (list[dict]): List of image information.

        Returns:
            all_cls_scores (Tensor): Outputs from the classification head,
                shape [nb_dec, bs, num_query, cls_out_channels]. Note
                cls_out_channels should include background.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression
                head with normalized coordinate format (cx, cy, w, h).
                Shape [nb_dec, bs, num_query, 4].
        """
        # Note `img_shape` is not dynamically traceable to ONNX,
        # since the related augmentation was done with numpy under
        # CPU. Thus `masks` is directly created with zeros (valid tag)
        # and the same spatial shape as `x`.
        # The difference between torch and exported ONNX model may be
        # ignored, since the same performance is achieved (e.g.
        # 40.1 vs 40.1 for DETR)
        batch_size = x.size(0)
        h, w = x.size()[-2:]
        masks = x.new_zeros((batch_size, h, w))  # [B, h, w]

        x = self.input_proj(x)
        # interpolate masks to have the same spatial shape with x
        masks = F.interpolate(
            masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)
        pos_embed = self.positional_encoding(masks)
        outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
                                       pos_embed)

        all_cls_scores = self.fc_cls(outs_dec)
        all_bbox_preds = self.fc_reg(self.activate(
            self.reg_ffn(outs_dec))).sigmoid()
        return all_cls_scores, all_bbox_preds

    def onnx_export(self, all_cls_scores_list, all_bbox_preds_list, img_metas):
        """Transform network outputs into bbox predictions, with ONNX
        exportation.

        Args:
            all_cls_scores_list (list[Tensor]): Classification outputs
                for each feature level. Each is a 4D-tensor with shape
                [nb_dec, bs, num_query, cls_out_channels].
            all_bbox_preds_list (list[Tensor]): Sigmoid regression
                outputs for each feature level. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                [nb_dec, bs, num_query, 4].
            img_metas (list[dict]): Meta information of each image.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
                and class labels of shape [N, num_det].
        """
        assert len(img_metas) == 1, \
            'Only support one input image while in exporting to ONNX'

        cls_scores = all_cls_scores_list[-1][-1]
        bbox_preds = all_bbox_preds_list[-1][-1]

        # Note `img_shape` is not dynamically traceable to ONNX,
        # here `img_shape_for_onnx` (padded shape of image tensor)
        # is used.
        img_shape = img_metas[0]['img_shape_for_onnx']
        max_per_img = self.test_cfg.get('max_per_img', self.num_query)
        batch_size = cls_scores.size(0)
        # `batch_index_offset` is used for the gather of concatenated tensor
        batch_index_offset = torch.arange(batch_size).to(
            cls_scores.device) * max_per_img
        batch_index_offset = batch_index_offset.unsqueeze(1).expand(
            batch_size, max_per_img)

        # supports dynamical batch inference
        if self.loss_cls.use_sigmoid:
            cls_scores = cls_scores.sigmoid()
            scores, indexes = cls_scores.view(batch_size, -1).topk(
                max_per_img, dim=1)
            det_labels = indexes % self.num_classes
            bbox_index = indexes // self.num_classes
            bbox_index = (bbox_index + batch_index_offset).view(-1)
            bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
            bbox_preds = bbox_preds.view(batch_size, -1, 4)
        else:
            scores, det_labels = F.softmax(
                cls_scores, dim=-1)[..., :-1].max(-1)
            scores, bbox_index = scores.topk(max_per_img, dim=1)
            bbox_index = (bbox_index + batch_index_offset).view(-1)
            bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
            det_labels = det_labels.view(-1)[bbox_index]
            bbox_preds = bbox_preds.view(batch_size, -1, 4)
            det_labels = det_labels.view(batch_size, -1)

        det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds)
        # use `img_shape_tensor` for dynamically exporting to ONNX
        img_shape_tensor = img_shape.flip(0).repeat(2)  # [w, h, w, h]
        img_shape_tensor = img_shape_tensor.unsqueeze(0).unsqueeze(0).expand(
            batch_size, det_bboxes.size(1), 4)
        det_bboxes = det_bboxes * img_shape_tensor
        # dynamically clip bboxes
        x1, y1, x2, y2 = det_bboxes.split((1, 1, 1, 1), dim=-1)
        from mmdet.core.export import dynamic_clip_for_onnx
        x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, img_shape)
        det_bboxes = torch.cat([x1, y1, x2, y2], dim=-1)
        det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1)

        return det_bboxes, det_labels
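
Usage note (not part of the file above): the sketch below is a minimal, hedged example of how this head is typically built and run through the mmdet registry. The transformer sub-config mirrors the structure of the standard DETR R-50 config; layer counts, shapes, and meta values here are illustrative assumptions rather than authoritative settings.

import torch
from mmdet.models import build_head

bbox_head_cfg = dict(
    type='DETRHead',
    num_classes=80,
    in_channels=2048,
    # Transformer sub-config adapted from the usual DETR R-50 setup.
    transformer=dict(
        type='Transformer',
        encoder=dict(
            type='DetrTransformerEncoder',
            num_layers=6,
            transformerlayers=dict(
                type='BaseTransformerLayer',
                attn_cfgs=[
                    dict(
                        type='MultiheadAttention',
                        embed_dims=256,
                        num_heads=8,
                        dropout=0.1)
                ],
                feedforward_channels=2048,
                ffn_dropout=0.1,
                operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
        decoder=dict(
            type='DetrTransformerDecoder',
            return_intermediate=True,
            num_layers=6,
            transformerlayers=dict(
                type='DetrTransformerDecoderLayer',
                attn_cfgs=dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1),
                feedforward_channels=2048,
                ffn_dropout=0.1,
                operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                 'ffn', 'norm')))),
    # positional_encoding, loss_cls, loss_bbox, loss_iou and train_cfg fall
    # back to the defaults defined in DETRHead.__init__ above.
    test_cfg=dict(max_per_img=100))

head = build_head(bbox_head_cfg)
head.init_weights()

# Dummy single-level C5 feature (stride 32) for one 800x1216 padded image;
# the meta values below are made up for illustration.
feats = (torch.randn(1, 2048, 25, 38), )
img_metas = [
    dict(
        batch_input_shape=(800, 1216),   # padded shape used to build masks
        img_shape=(800, 1199, 3),        # valid (unpadded) region
        scale_factor=[1.0, 1.0, 1.0, 1.0])
]

all_cls_scores_list, all_bbox_preds_list = head(feats, img_metas)
# all_cls_scores_list[0]: [num_dec_layers, 1, num_query, num_classes + 1]
results = head.get_bboxes(all_cls_scores_list, all_bbox_preds_list,
                          img_metas, rescale=True)
det_bboxes, det_labels = results[0]  # (max_per_img, 5) and (max_per_img, )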
