deformable_detr_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Linear, bias_init_with_prob, constant_init
from mmcv.runner import force_fp32

from mmdet.core import multi_apply
from mmdet.models.utils.transformer import inverse_sigmoid
from ..builder import HEADS
from .detr_head import DETRHead


@HEADS.register_module()
class DeformableDETRHead(DETRHead):
    """Head of Deformable DETR: Deformable Transformers for End-to-End
    Object Detection.

    Code is modified from the `official github repo
    <https://github.com/fundamentalvision/Deformable-DETR>`_.

    More details can be found in the `paper
    <https://arxiv.org/abs/2010.04159>`_.

    Args:
        with_box_refine (bool): Whether to refine the reference points
            in the decoder. Defaults to False.
        as_two_stage (bool): Whether to generate proposals from the
            encoder outputs. Defaults to False.
        transformer (obj:`ConfigDict`): ConfigDict used for building
            the encoder and decoder.
    """
    def __init__(self,
                 *args,
                 with_box_refine=False,
                 as_two_stage=False,
                 transformer=None,
                 **kwargs):
        self.with_box_refine = with_box_refine
        self.as_two_stage = as_two_stage
        if self.as_two_stage:
            transformer['as_two_stage'] = self.as_two_stage

        super(DeformableDETRHead, self).__init__(
            *args, transformer=transformer, **kwargs)
    def _init_layers(self):
        """Initialize classification branch and regression branch of head."""
        fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU())
        reg_branch.append(Linear(self.embed_dims, 4))
        reg_branch = nn.Sequential(*reg_branch)

        def _get_clones(module, N):
            return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

        # The last reg_branch is used to generate proposals from the
        # encoded feature map when as_two_stage is True.
        num_pred = (self.transformer.decoder.num_layers + 1) if \
            self.as_two_stage else self.transformer.decoder.num_layers
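        # With box refinement every decoder layer needs its own deep-copied
        # branch (each layer predicts a different correction); otherwise a
        # single branch instance is shared across all decoder layers.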
        if self.with_box_refine:
            self.cls_branches = _get_clones(fc_cls, num_pred)
            self.reg_branches = _get_clones(reg_branch, num_pred)
        else:
            self.cls_branches = nn.ModuleList(
                [fc_cls for _ in range(num_pred)])
            self.reg_branches = nn.ModuleList(
                [reg_branch for _ in range(num_pred)])

        if not self.as_two_stage:
            self.query_embedding = nn.Embedding(self.num_query,
                                                self.embed_dims * 2)
    def init_weights(self):
        """Initialize weights of the DeformDETR head."""
        self.transformer.init_weights()
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            for m in self.cls_branches:
                nn.init.constant_(m.bias, bias_init)
        for m in self.reg_branches:
            constant_init(m[-1], 0, bias=0)
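        # Bias the width/height logits of the first regression branch
        # towards -2.0 (sigmoid(-2.0) is roughly 0.12) so the initial
        # predicted boxes start out small, following Deformable-DETR.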
        nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
        if self.as_two_stage:
            for m in self.reg_branches:
                nn.init.constant_(m[-1].bias.data[2:], 0.0)
    def forward(self, mlvl_feats, img_metas):
        """Forward function.

        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream
                network, each is a 4D-tensor with shape
                (N, C, H, W).
            img_metas (list[dict]): List of image information.

        Returns:
            all_cls_scores (Tensor): Outputs from the classification head, \
                shape [nb_dec, bs, num_query, cls_out_channels]. Note \
                cls_out_channels should include background.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
                head with normalized coordinate format (cx, cy, w, h). \
                Shape [nb_dec, bs, num_query, 4].
            enc_outputs_class (Tensor): The score of each point on the \
                encoded feature map, has shape (N, h*w, num_class). Only \
                returned when as_two_stage is True, otherwise `None`.
            enc_outputs_coord (Tensor): The proposals generated from the \
                encoded feature map, has shape (N, h*w, 4). Only returned \
                when as_two_stage is True, otherwise `None`.
        """
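        # Build a per-image padding mask: the batch is padded to a common
        # `batch_input_shape`, so pixels outside each image's own
        # `img_shape` are marked with 1 (padding) and valid pixels with 0.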
        batch_size = mlvl_feats[0].size(0)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        img_masks = mlvl_feats[0].new_ones(
            (batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]['img_shape']
            img_masks[img_id, :img_h, :img_w] = 0
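        # Downsample the padding mask to each feature level and compute the
        # positional encodings on the downsampled masks.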
        mlvl_masks = []
        mlvl_positional_encodings = []
        for feat in mlvl_feats:
            mlvl_masks.append(
                F.interpolate(img_masks[None],
                              size=feat.shape[-2:]).to(torch.bool).squeeze(0))
            mlvl_positional_encodings.append(
                self.positional_encoding(mlvl_masks[-1]))

        query_embeds = None
        if not self.as_two_stage:
            query_embeds = self.query_embedding.weight
        hs, init_reference, inter_references, \
            enc_outputs_class, enc_outputs_coord = self.transformer(
                mlvl_feats,
                mlvl_masks,
                query_embeds,
                mlvl_positional_encodings,
                reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
                cls_branches=self.cls_branches if self.as_two_stage else None  # noqa:E501
            )
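        # hs comes back as (num_dec_layers, num_query, bs, embed_dims);
        # permute to (num_dec_layers, bs, num_query, embed_dims) so each
        # decoder layer's output can be fed to its own prediction branch.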
        hs = hs.permute(0, 2, 1, 3)
        outputs_classes = []
        outputs_coords = []

        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
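            # The reference points live in [0, 1] (sigmoid space). Map them
            # back to logit space with inverse_sigmoid so the raw regression
            # offsets can simply be added before re-applying sigmoid.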
            reference = inverse_sigmoid(reference)
            outputs_class = self.cls_branches[lvl](hs[lvl])
            tmp = self.reg_branches[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)

        outputs_classes = torch.stack(outputs_classes)
        outputs_coords = torch.stack(outputs_coords)
        if self.as_two_stage:
            return outputs_classes, outputs_coords, \
                enc_outputs_class, \
                enc_outputs_coord.sigmoid()
        else:
            return outputs_classes, outputs_coords, \
                None, None
    @force_fp32(apply_to=('all_cls_scores', 'all_bbox_preds'))
    def loss(self,
             all_cls_scores,
             all_bbox_preds,
             enc_cls_scores,
             enc_bbox_preds,
             gt_bboxes_list,
             gt_labels_list,
             img_metas,
             gt_bboxes_ignore=None):
  167. """"Loss function.
  168. Args:
  169. all_cls_scores (Tensor): Classification score of all
  170. decoder layers, has shape
  171. [nb_dec, bs, num_query, cls_out_channels].
  172. all_bbox_preds (Tensor): Sigmoid regression
  173. outputs of all decode layers. Each is a 4D-tensor with
  174. normalized coordinate format (cx, cy, w, h) and shape
  175. [nb_dec, bs, num_query, 4].
  176. enc_cls_scores (Tensor): Classification scores of
  177. points on encode feature map , has shape
  178. (N, h*w, num_classes). Only be passed when as_two_stage is
  179. True, otherwise is None.
  180. enc_bbox_preds (Tensor): Regression results of each points
  181. on the encode feature map, has shape (N, h*w, 4). Only be
  182. passed when as_two_stage is True, otherwise is None.
  183. gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
  184. with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
  185. gt_labels_list (list[Tensor]): Ground truth class indices for each
  186. image with shape (num_gts, ).
  187. img_metas (list[dict]): List of image meta information.
  188. gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
  189. which can be ignored for each image. Default None.
  190. Returns:
  191. dict[str, Tensor]: A dictionary of loss components.
  192. """
        assert gt_bboxes_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            f'gt_bboxes_ignore set to None.'

        num_dec_layers = len(all_cls_scores)
        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_bboxes_ignore_list = [
            gt_bboxes_ignore for _ in range(num_dec_layers)
        ]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]
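        # Compute classification/bbox/IoU losses for every decoder layer in
        # a single multi_apply call; the ground truths were replicated per
        # layer above, giving deep supervision over all decoder layers.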
        losses_cls, losses_bbox, losses_iou = multi_apply(
            self.loss_single, all_cls_scores, all_bbox_preds,
            all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
            all_gt_bboxes_ignore_list)

        loss_dict = dict()
        # loss of proposals generated from the encoded feature map.
        if enc_cls_scores is not None:
            binary_labels_list = [
                torch.zeros_like(gt_labels_list[i])
                for i in range(len(img_metas))
            ]
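            # All ground-truth labels are mapped to class 0: the encoder
            # proposals are scored class-agnostically, so this loss only
            # supervises foreground-vs-background.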
            enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
                self.loss_single(enc_cls_scores, enc_bbox_preds,
                                 gt_bboxes_list, binary_labels_list,
                                 img_metas, gt_bboxes_ignore)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox
            loss_dict['enc_loss_iou'] = enc_losses_iou

        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]
        loss_dict['loss_iou'] = losses_iou[-1]
        # loss from other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
                                                       losses_bbox[:-1],
                                                       losses_iou[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
            loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
            num_dec_layer += 1
        return loss_dict
    @force_fp32(apply_to=('all_cls_scores', 'all_bbox_preds'))
    def get_bboxes(self,
                   all_cls_scores,
                   all_bbox_preds,
                   enc_cls_scores,
                   enc_bbox_preds,
                   img_metas,
                   rescale=False):
  243. """Transform network outputs for a batch into bbox predictions.
  244. Args:
  245. all_cls_scores (Tensor): Classification score of all
  246. decoder layers, has shape
  247. [nb_dec, bs, num_query, cls_out_channels].
  248. all_bbox_preds (Tensor): Sigmoid regression
  249. outputs of all decode layers. Each is a 4D-tensor with
  250. normalized coordinate format (cx, cy, w, h) and shape
  251. [nb_dec, bs, num_query, 4].
  252. enc_cls_scores (Tensor): Classification scores of
  253. points on encode feature map , has shape
  254. (N, h*w, num_classes). Only be passed when as_two_stage is
  255. True, otherwise is None.
  256. enc_bbox_preds (Tensor): Regression results of each points
  257. on the encode feature map, has shape (N, h*w, 4). Only be
  258. passed when as_two_stage is True, otherwise is None.
  259. img_metas (list[dict]): Meta information of each image.
  260. rescale (bool, optional): If True, return boxes in original
  261. image space. Default False.
  262. Returns:
  263. list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \
  264. The first item is an (n, 5) tensor, where the first 4 columns \
  265. are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
  266. 5-th column is a score between 0 and 1. The second item is a \
  267. (n,) tensor where each item is the predicted class label of \
  268. the corresponding box.
  269. """
        cls_scores = all_cls_scores[-1]
        bbox_preds = all_bbox_preds[-1]

        result_list = []
        for img_id in range(len(img_metas)):
            cls_score = cls_scores[img_id]
            bbox_pred = bbox_preds[img_id]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self._get_bboxes_single(cls_score, bbox_pred,
                                                img_shape, scale_factor,
                                                rescale)
            result_list.append(proposals)
        return result_list
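

if __name__ == '__main__':
    # Minimal self-contained sketch (not part of the original file): it
    # replays the box-refinement arithmetic from forward() on dummy tensors,
    # assuming 2-d reference points (cx, cy). The shapes used here
    # (bs=2, num_query=300) are illustrative assumptions, not values fixed
    # by the head itself.
    reference = torch.rand(2, 300, 2)        # reference points in [0, 1]
    tmp = torch.randn(2, 300, 4)             # raw output of a reg_branch
    reference = inverse_sigmoid(reference)   # map back to logit space
    tmp[..., :2] += reference                # offset the box centers only
    outputs_coord = tmp.sigmoid()            # normalized (cx, cy, w, h)
    print(outputs_coord.shape)               # torch.Size([2, 300, 4])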
