
dii_head.py 19 kB

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (bias_init_with_prob, build_activation_layer,
                      build_norm_layer)
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.runner import auto_fp16, force_fp32

from mmdet.core import multi_apply
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.dense_heads.atss_head import reduce_mean
from mmdet.models.losses import accuracy
from mmdet.models.utils import build_transformer
from .bbox_head import BBoxHead


@HEADS.register_module()
class DIIHead(BBoxHead):
    r"""Dynamic Instance Interactive Head for `Sparse R-CNN: End-to-End Object
    Detection with Learnable Proposals <https://arxiv.org/abs/2011.12450>`_

    Args:
        num_classes (int): Number of classes in the dataset.
            Defaults to 80.
        num_ffn_fcs (int): The number of fully-connected
            layers in FFNs. Defaults to 2.
        num_heads (int): The number of attention heads in
            MultiheadAttention. Defaults to 8.
        num_cls_fcs (int): The number of fully-connected
            layers in the classification subnet. Defaults to 1.
        num_reg_fcs (int): The number of fully-connected
            layers in the regression subnet. Defaults to 3.
        feedforward_channels (int): The hidden dimension
            of FFNs. Defaults to 2048.
        in_channels (int): Hidden channels of MultiheadAttention.
            Defaults to 256.
        dropout (float): Probability of dropping a channel.
            Defaults to 0.0.
        ffn_act_cfg (dict): The activation config for FFNs.
        dynamic_conv_cfg (dict): The convolution config
            for DynamicConv.
        loss_iou (dict): The config for the IoU or GIoU loss.
    """

    def __init__(self,
                 num_classes=80,
                 num_ffn_fcs=2,
                 num_heads=8,
                 num_cls_fcs=1,
                 num_reg_fcs=3,
                 feedforward_channels=2048,
                 in_channels=256,
                 dropout=0.0,
                 ffn_act_cfg=dict(type='ReLU', inplace=True),
                 dynamic_conv_cfg=dict(
                     type='DynamicConv',
                     in_channels=256,
                     feat_channels=64,
                     out_channels=256,
                     input_feat_shape=7,
                     act_cfg=dict(type='ReLU', inplace=True),
                     norm_cfg=dict(type='LN')),
                 loss_iou=dict(type='GIoULoss', loss_weight=2.0),
                 init_cfg=None,
                 **kwargs):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(DIIHead, self).__init__(
            num_classes=num_classes,
            reg_decoded_bbox=True,
            reg_class_agnostic=True,
            init_cfg=init_cfg,
            **kwargs)
        self.loss_iou = build_loss(loss_iou)
        self.in_channels = in_channels
        self.fp16_enabled = False
        self.attention = MultiheadAttention(in_channels, num_heads, dropout)
        self.attention_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.instance_interactive_conv = build_transformer(dynamic_conv_cfg)
        self.instance_interactive_conv_dropout = nn.Dropout(dropout)
        self.instance_interactive_conv_norm = build_norm_layer(
            dict(type='LN'), in_channels)[1]

        self.ffn = FFN(
            in_channels,
            feedforward_channels,
            num_ffn_fcs,
            act_cfg=ffn_act_cfg,
            dropout=dropout)
        self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]

        self.cls_fcs = nn.ModuleList()
        for _ in range(num_cls_fcs):
            self.cls_fcs.append(
                nn.Linear(in_channels, in_channels, bias=False))
            self.cls_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.cls_fcs.append(
                build_activation_layer(dict(type='ReLU', inplace=True)))
        # overload the self.fc_cls in BBoxHead
        if self.loss_cls.use_sigmoid:
            self.fc_cls = nn.Linear(in_channels, self.num_classes)
        else:
            self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)

        self.reg_fcs = nn.ModuleList()
        for _ in range(num_reg_fcs):
            self.reg_fcs.append(
                nn.Linear(in_channels, in_channels, bias=False))
            self.reg_fcs.append(
                build_norm_layer(dict(type='LN'), in_channels)[1])
            self.reg_fcs.append(
                build_activation_layer(dict(type='ReLU', inplace=True)))
        # overload the self.fc_reg in BBoxHead
        self.fc_reg = nn.Linear(in_channels, 4)

        assert self.reg_class_agnostic, 'DIIHead only ' \
            'supports `reg_class_agnostic=True` '
        assert self.reg_decoded_bbox, 'DIIHead only ' \
            'supports `reg_decoded_bbox=True`'

    def init_weights(self):
        """Use Xavier initialization for all weight parameters and set the
        classification head bias to a specific value when using focal loss."""
        super(DIIHead, self).init_weights()
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                # adopt the default initialization for
                # the weight and bias of the layer norm
                pass
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.fc_cls.bias, bias_init)

    @auto_fp16()
    def forward(self, roi_feat, proposal_feat):
        """Forward function of Dynamic Instance Interactive Head.

        Args:
            roi_feat (Tensor): RoI-pooled features with shape
                (batch_size*num_proposals, feature_dimensions,
                pooling_h, pooling_w).
            proposal_feat (Tensor): Intermediate feature obtained from the
                DIIHead of the last stage, has shape
                (batch_size, num_proposals, feature_dimensions).

        Returns:
            tuple[Tensor]: Usually a tuple of classification scores,
            bbox predictions and intermediate features.

                - cls_scores (Tensor): Classification scores for
                  all proposals, has shape
                  (batch_size, num_proposals, num_classes).
                - bbox_preds (Tensor): Box energies / deltas for
                  all proposals, has shape
                  (batch_size, num_proposals, 4).
                - obj_feat (Tensor): Object feature before the classification
                  and regression subnets, has shape
                  (batch_size, num_proposals, feature_dimensions).
                - attn_feats (Tensor): Intermediate feature after the
                  self-attention block, has shape
                  (batch_size, num_proposals, feature_dimensions).
        """
        N, num_proposals = proposal_feat.shape[:2]

        # self attention
        proposal_feat = proposal_feat.permute(1, 0, 2)
        proposal_feat = self.attention_norm(self.attention(proposal_feat))
        attn_feats = proposal_feat.permute(1, 0, 2)

        # instance interactive
        proposal_feat = attn_feats.reshape(-1, self.in_channels)
        proposal_feat_iic = self.instance_interactive_conv(
            proposal_feat, roi_feat)
        proposal_feat = proposal_feat + self.instance_interactive_conv_dropout(
            proposal_feat_iic)
        obj_feat = self.instance_interactive_conv_norm(proposal_feat)

        # FFN
        obj_feat = self.ffn_norm(self.ffn(obj_feat))

        cls_feat = obj_feat
        reg_feat = obj_feat

        for cls_layer in self.cls_fcs:
            cls_feat = cls_layer(cls_feat)
        for reg_layer in self.reg_fcs:
            reg_feat = reg_layer(reg_feat)

        cls_score = self.fc_cls(cls_feat).view(
            N, num_proposals, self.num_classes
            if self.loss_cls.use_sigmoid else self.num_classes + 1)
        bbox_delta = self.fc_reg(reg_feat).view(N, num_proposals, 4)

        return cls_score, bbox_delta, obj_feat.view(
            N, num_proposals, self.in_channels), attn_feats

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def loss(self,
             cls_score,
             bbox_pred,
             labels,
             label_weights,
             bbox_targets,
             bbox_weights,
             imgs_whwh=None,
             reduction_override=None,
             **kwargs):
        """Loss function of DIIHead, which computes the loss of all images.

        Args:
            cls_score (Tensor): Classification prediction
                results of all classes, has shape
                (batch_size * num_proposals_single_image, num_classes).
            bbox_pred (Tensor): Regression prediction results, has shape
                (batch_size * num_proposals_single_image, 4), the last
                dimension 4 represents [tl_x, tl_y, br_x, br_y].
            labels (Tensor): Label of each proposal, has shape
                (batch_size * num_proposals_single_image,).
            label_weights (Tensor): Classification loss
                weight of each proposal, has shape
                (batch_size * num_proposals_single_image,).
            bbox_targets (Tensor): Regression targets of each
                proposal, has shape
                (batch_size * num_proposals_single_image, 4),
                the last dimension 4 represents
                [tl_x, tl_y, br_x, br_y].
            bbox_weights (Tensor): Regression loss weight of each
                proposal's coordinates, has shape
                (batch_size * num_proposals_single_image, 4).
            imgs_whwh (Tensor): Tensor with shape
                (batch_size, num_proposals, 4), the last
                dimension means
                [img_width, img_height, img_width, img_height].
            reduction_override (str, optional): The reduction
                method used to override the original reduction
                method of the loss. Options are "none",
                "mean" and "sum". Defaults to None.

        Returns:
            dict[str, Tensor]: Dictionary of loss components.
        """
        losses = dict()
        bg_class_ind = self.num_classes
        # note in sparse rcnn num_gt == num_pos
        pos_inds = (labels >= 0) & (labels < bg_class_ind)
        num_pos = pos_inds.sum().float()
        avg_factor = reduce_mean(num_pos)
        if cls_score is not None:
            if cls_score.numel() > 0:
                losses['loss_cls'] = self.loss_cls(
                    cls_score,
                    labels,
                    label_weights,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                losses['pos_acc'] = accuracy(cls_score[pos_inds],
                                             labels[pos_inds])
        if bbox_pred is not None:
            # 0~self.num_classes-1 are FG, self.num_classes is BG
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                pos_bbox_pred = bbox_pred.reshape(bbox_pred.size(0),
                                                  4)[pos_inds.type(torch.bool)]
                imgs_whwh = imgs_whwh.reshape(bbox_pred.size(0),
                                              4)[pos_inds.type(torch.bool)]
                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred / imgs_whwh,
                    bbox_targets[pos_inds.type(torch.bool)] / imgs_whwh,
                    bbox_weights[pos_inds.type(torch.bool)],
                    avg_factor=avg_factor)
                losses['loss_iou'] = self.loss_iou(
                    pos_bbox_pred,
                    bbox_targets[pos_inds.type(torch.bool)],
                    bbox_weights[pos_inds.type(torch.bool)],
                    avg_factor=avg_factor)
            else:
                losses['loss_bbox'] = bbox_pred.sum() * 0
                losses['loss_iou'] = bbox_pred.sum() * 0
        return losses

    def _get_target_single(self, pos_inds, neg_inds, pos_bboxes, neg_bboxes,
                           pos_gt_bboxes, pos_gt_labels, cfg):
        """Calculate the ground truth for proposals in a single image
        according to the sampling results.

        Almost the same as the implementation in `bbox_head`,
        we add pos_inds and neg_inds to select positive and
        negative samples instead of selecting the first num_pos
        as positive samples.

        Args:
            pos_inds (Tensor): The length is equal to the
                positive sample number and it contains all indices
                of the positive samples in the original proposal set.
            neg_inds (Tensor): The length is equal to the
                negative sample number and it contains all indices
                of the negative samples in the original proposal set.
            pos_bboxes (Tensor): Contains all the positive boxes,
                has shape (num_pos, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            neg_bboxes (Tensor): Contains all the negative boxes,
                has shape (num_neg, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_bboxes (Tensor): Contains all the gt_boxes,
                has shape (num_gt, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_labels (Tensor): Contains all the gt_labels,
                has shape (num_gt,).
            cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a single image.
            Containing the following Tensors:

                - labels (Tensor): Gt_labels for all proposals, has
                  shape (num_proposals,).
                - label_weights (Tensor): Label weights for all proposals,
                  has shape (num_proposals,).
                - bbox_targets (Tensor): Regression targets for all
                  proposals, has shape (num_proposals, 4), the last
                  dimension 4 represents [tl_x, tl_y, br_x, br_y].
                - bbox_weights (Tensor): Regression weights for all
                  proposals, has shape (num_proposals, 4).
        """
        num_pos = pos_bboxes.size(0)
        num_neg = neg_bboxes.size(0)
        num_samples = num_pos + num_neg

        # original implementation uses new_zeros since BG are set to be 0
        # now use empty & fill because BG cat_id = num_classes,
        # FG cat_id = [0, num_classes-1]
        labels = pos_bboxes.new_full((num_samples, ),
                                     self.num_classes,
                                     dtype=torch.long)
        label_weights = pos_bboxes.new_zeros(num_samples)
        bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
        bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
        if num_pos > 0:
            labels[pos_inds] = pos_gt_labels
            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
            label_weights[pos_inds] = pos_weight
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    pos_bboxes, pos_gt_bboxes)
            else:
                pos_bbox_targets = pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1
        if num_neg > 0:
            label_weights[neg_inds] = 1.0

        return labels, label_weights, bbox_targets, bbox_weights

    def get_targets(self,
                    sampling_results,
                    gt_bboxes,
                    gt_labels,
                    rcnn_train_cfg,
                    concat=True):
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Almost the same as the implementation in bbox_head, we passed
        additional parameters pos_inds_list and neg_inds_list to the
        `_get_target_single` function.

        Args:
            sampling_results (List[obj:SamplingResults]): Assign results of
                all images in a batch after sampling.
            gt_bboxes (list[Tensor]): Gt_bboxes of all images in a batch,
                each tensor has shape (num_gt, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            gt_labels (list[Tensor]): Gt_labels of all images in a batch,
                each tensor has shape (num_gt,).
            rcnn_train_cfg (obj:`ConfigDict`): `train_cfg` of RCNN.
            concat (bool): Whether to concatenate the results of all
                the images in a single batch.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

                - labels (list[Tensor], Tensor): Gt_labels for all
                  proposals in a batch, each tensor in the list has
                  shape (num_proposals,) when `concat=False`, otherwise just
                  a single tensor of shape (num_all_proposals,).
                - label_weights (list[Tensor]): Label weights for
                  all proposals in a batch, each tensor in the list has shape
                  (num_proposals,) when `concat=False`, otherwise just a
                  single tensor of shape (num_all_proposals,).
                - bbox_targets (list[Tensor], Tensor): Regression targets
                  for all proposals in a batch, each tensor in the list has
                  shape (num_proposals, 4) when `concat=False`, otherwise
                  just a single tensor of shape (num_all_proposals, 4),
                  the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
                - bbox_weights (list[Tensor], Tensor): Regression weights for
                  all proposals in a batch, each tensor in the list has shape
                  (num_proposals, 4) when `concat=False`, otherwise just a
                  single tensor of shape (num_all_proposals, 4).
        """
        pos_inds_list = [res.pos_inds for res in sampling_results]
        neg_inds_list = [res.neg_inds for res in sampling_results]
        pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
        neg_bboxes_list = [res.neg_bboxes for res in sampling_results]
        pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
        labels, label_weights, bbox_targets, bbox_weights = multi_apply(
            self._get_target_single,
            pos_inds_list,
            neg_inds_list,
            pos_bboxes_list,
            neg_bboxes_list,
            pos_gt_bboxes_list,
            pos_gt_labels_list,
            cfg=rcnn_train_cfg)
        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bbox_targets = torch.cat(bbox_targets, 0)
            bbox_weights = torch.cat(bbox_weights, 0)
        return labels, label_weights, bbox_targets, bbox_weights
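
Below is a minimal usage sketch, not part of the file above. It assumes mmdet 2.x and mmcv-full are installed, builds the head with its default config, and pushes dummy tensors through forward(); the shapes follow the forward() docstring, and the batch and proposal counts are arbitrary.

# Hypothetical smoke test for DIIHead (assumes mmdet 2.x / mmcv-full).
import torch
from mmdet.models.roi_heads.bbox_heads import DIIHead

batch_size, num_proposals, in_channels = 2, 100, 256

head = DIIHead(num_classes=80)
head.init_weights()

# RoI-pooled features: (batch_size * num_proposals, C, pooling_h, pooling_w)
roi_feat = torch.rand(batch_size * num_proposals, in_channels, 7, 7)
# proposal features from the previous stage: (batch_size, num_proposals, C)
proposal_feat = torch.rand(batch_size, num_proposals, in_channels)

cls_score, bbox_delta, obj_feat, attn_feats = head(roi_feat, proposal_feat)
print(cls_score.shape)   # torch.Size([2, 100, 81]); 81 because the default
                         # CrossEntropyLoss keeps an extra background class
print(bbox_delta.shape)  # torch.Size([2, 100, 4])
print(obj_feat.shape)    # torch.Size([2, 100, 256])
print(attn_feats.shape)  # torch.Size([2, 100, 256])

With the Sparse R-CNN config (FocalLoss with use_sigmoid=True), fc_cls outputs num_classes logits instead, so cls_score would have shape (2, 100, 80).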

