
bbox_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from torch.nn.modules.utils import _pair

from mmdet.core import build_bbox_coder, multi_apply, multiclass_nms
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.losses import accuracy
from mmdet.models.utils import build_linear_layer


@HEADS.register_module()
class BBoxHead(BaseModule):
    """Simplest RoI head, with only two fc layers for classification and
    regression respectively."""

    def __init__(self,
                 with_avg_pool=False,
                 with_cls=True,
                 with_reg=True,
                 roi_feat_size=7,
                 in_channels=256,
                 num_classes=80,
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=[0., 0., 0., 0.],
                     target_stds=[0.1, 0.1, 0.2, 0.2]),
                 reg_class_agnostic=False,
                 reg_decoded_bbox=False,
                 reg_predictor_cfg=dict(type='Linear'),
                 cls_predictor_cfg=dict(type='Linear'),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
                 init_cfg=None):
        super(BBoxHead, self).__init__(init_cfg)
        assert with_cls or with_reg
        self.with_avg_pool = with_avg_pool
        self.with_cls = with_cls
        self.with_reg = with_reg
        self.roi_feat_size = _pair(roi_feat_size)
        self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.reg_class_agnostic = reg_class_agnostic
        self.reg_decoded_bbox = reg_decoded_bbox
        self.reg_predictor_cfg = reg_predictor_cfg
        self.cls_predictor_cfg = cls_predictor_cfg
        self.fp16_enabled = False

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)

        in_channels = self.in_channels
        if self.with_avg_pool:
            self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
        else:
            in_channels *= self.roi_feat_area
        if self.with_cls:
            # need to add background class
            if self.custom_cls_channels:
                cls_channels = self.loss_cls.get_cls_channels(self.num_classes)
            else:
                cls_channels = num_classes + 1
            self.fc_cls = build_linear_layer(
                self.cls_predictor_cfg,
                in_features=in_channels,
                out_features=cls_channels)
        if self.with_reg:
            out_dim_reg = 4 if reg_class_agnostic else 4 * num_classes
            self.fc_reg = build_linear_layer(
                self.reg_predictor_cfg,
                in_features=in_channels,
                out_features=out_dim_reg)
        self.debug_imgs = None
        if init_cfg is None:
            self.init_cfg = []
            if self.with_cls:
                self.init_cfg += [
                    dict(
                        type='Normal', std=0.01, override=dict(name='fc_cls'))
                ]
            if self.with_reg:
                self.init_cfg += [
                    dict(
                        type='Normal', std=0.001, override=dict(name='fc_reg'))
                ]

    @property
    def custom_cls_channels(self):
        return getattr(self.loss_cls, 'custom_cls_channels', False)

    @property
    def custom_activation(self):
        return getattr(self.loss_cls, 'custom_activation', False)

    @property
    def custom_accuracy(self):
        return getattr(self.loss_cls, 'custom_accuracy', False)
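
    # These `custom_*` hooks let the configured classification loss (e.g.
    # Seesaw loss, referenced in `get_bboxes` below) supply its own logit
    # layout, activation and accuracy computation; losses that do not define
    # the corresponding attributes fall back to the standard softmax path.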

    @auto_fp16()
    def forward(self, x):
        if self.with_avg_pool:
            if x.numel() > 0:
                x = self.avg_pool(x)
                x = x.view(x.size(0), -1)
            else:
                # avg_pool does not support empty tensors,
                # so use torch.mean instead
                x = torch.mean(x, dim=(-1, -2))
        cls_score = self.fc_cls(x) if self.with_cls else None
        bbox_pred = self.fc_reg(x) if self.with_reg else None
        return cls_score, bbox_pred
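
    # Shape sketch for `forward`, assuming with_avg_pool=True and the
    # defaults above: RoI features of shape (num_rois, 256, 7, 7) are pooled
    # and flattened to (num_rois, 256); cls_score is then
    # (num_rois, num_classes + 1) and bbox_pred is (num_rois, 4) when
    # reg_class_agnostic else (num_rois, 4 * num_classes).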

    def _get_target_single(self, pos_bboxes, neg_bboxes, pos_gt_bboxes,
                           pos_gt_labels, cfg):
        """Calculate the ground truth for proposals in a single image
        according to the sampling results.

        Args:
            pos_bboxes (Tensor): Contains all the positive boxes,
                has shape (num_pos, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            neg_bboxes (Tensor): Contains all the negative boxes,
                has shape (num_neg, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_bboxes (Tensor): Contains all the gt_boxes,
                has shape (num_gt, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            pos_gt_labels (Tensor): Contains all the gt_labels,
                has shape (num_gt,).
            cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.

        Returns:
            Tuple[Tensor]: Ground truth for proposals
            in a single image. Containing the following Tensors:

                - labels (Tensor): Gt_labels for all proposals, has
                  shape (num_proposals,).
                - label_weights (Tensor): Label weights for all
                  proposals, has shape (num_proposals,).
                - bbox_targets (Tensor): Regression targets for all
                  proposals, has shape (num_proposals, 4), the
                  last dimension 4 represents [tl_x, tl_y, br_x, br_y].
                - bbox_weights (Tensor): Regression weights for all
                  proposals, has shape (num_proposals, 4).
        """
        num_pos = pos_bboxes.size(0)
        num_neg = neg_bboxes.size(0)
        num_samples = num_pos + num_neg

        # the original implementation uses new_zeros since BG was class 0;
        # now use full with num_classes because BG cat_id = num_classes,
        # FG cat_id = [0, num_classes-1]
        labels = pos_bboxes.new_full((num_samples, ),
                                     self.num_classes,
                                     dtype=torch.long)
        label_weights = pos_bboxes.new_zeros(num_samples)
        bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
        bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
        if num_pos > 0:
            labels[:num_pos] = pos_gt_labels
            pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
            label_weights[:num_pos] = pos_weight
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    pos_bboxes, pos_gt_bboxes)
            else:
                # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
                # is applied directly on the decoded bounding boxes, both
                # the predicted boxes and regression targets should be in
                # absolute coordinate format.
                pos_bbox_targets = pos_gt_bboxes
            bbox_targets[:num_pos, :] = pos_bbox_targets
            bbox_weights[:num_pos, :] = 1
        if num_neg > 0:
            label_weights[-num_neg:] = 1.0

        return labels, label_weights, bbox_targets, bbox_weights
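
    # A worked layout example for the targets above, assuming num_classes=80,
    # num_pos=2 and num_neg=3: labels == [gt0, gt1, 80, 80, 80] (background
    # ids equal num_classes), label_weights == [pos_w, pos_w, 1., 1., 1.],
    # and only the first num_pos rows of bbox_targets / bbox_weights are
    # non-zero.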

    def get_targets(self,
                    sampling_results,
                    gt_bboxes,
                    gt_labels,
                    rcnn_train_cfg,
                    concat=True):
        """Calculate the ground truth for all samples in a batch according to
        the sampling_results.

        Args:
            sampling_results (List[obj:SamplingResults]): Assign results of
                all images in a batch after sampling.
            gt_bboxes (list[Tensor]): Gt_bboxes of all images in a batch,
                each tensor has shape (num_gt, 4), the last dimension 4
                represents [tl_x, tl_y, br_x, br_y].
            gt_labels (list[Tensor]): Gt_labels of all images in a batch,
                each tensor has shape (num_gt,).
            rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
            concat (bool): Whether to concatenate the results of all
                the images in a single batch.

        Returns:
            Tuple[Tensor]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

                - labels (list[Tensor], Tensor): Gt_labels for all
                  proposals in a batch, each tensor in the list has
                  shape (num_proposals,) when `concat=False`, otherwise
                  just a single tensor of shape (num_all_proposals,).
                - label_weights (list[Tensor], Tensor): Label weights for
                  all proposals in a batch, each tensor in the list has
                  shape (num_proposals,) when `concat=False`, otherwise
                  just a single tensor of shape (num_all_proposals,).
                - bbox_targets (list[Tensor], Tensor): Regression targets
                  for all proposals in a batch, each tensor in the list
                  has shape (num_proposals, 4) when `concat=False`,
                  otherwise just a single tensor of shape
                  (num_all_proposals, 4), the last dimension 4 represents
                  [tl_x, tl_y, br_x, br_y].
                - bbox_weights (list[Tensor], Tensor): Regression weights
                  for all proposals in a batch, each tensor in the list
                  has shape (num_proposals, 4) when `concat=False`,
                  otherwise just a single tensor of shape
                  (num_all_proposals, 4).
        """
        pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
        neg_bboxes_list = [res.neg_bboxes for res in sampling_results]
        pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
        labels, label_weights, bbox_targets, bbox_weights = multi_apply(
            self._get_target_single,
            pos_bboxes_list,
            neg_bboxes_list,
            pos_gt_bboxes_list,
            pos_gt_labels_list,
            cfg=rcnn_train_cfg)

        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bbox_targets = torch.cat(bbox_targets, 0)
            bbox_weights = torch.cat(bbox_weights, 0)
        return labels, label_weights, bbox_targets, bbox_weights
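
    # With concat=True (the default) each of the four returned tensors spans
    # every image in the batch, with length equal to the total number of
    # sampled proposals; with concat=False the per-image lists produced by
    # `multi_apply` are returned unchanged.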

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def loss(self,
             cls_score,
             bbox_pred,
             rois,
             labels,
             label_weights,
             bbox_targets,
             bbox_weights,
             reduction_override=None):
        losses = dict()
        if cls_score is not None:
            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
            if cls_score.numel() > 0:
                loss_cls_ = self.loss_cls(
                    cls_score,
                    labels,
                    label_weights,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                if isinstance(loss_cls_, dict):
                    losses.update(loss_cls_)
                else:
                    losses['loss_cls'] = loss_cls_
                if self.custom_activation:
                    acc_ = self.loss_cls.get_accuracy(cls_score, labels)
                    losses.update(acc_)
                else:
                    losses['acc'] = accuracy(cls_score, labels)
        if bbox_pred is not None:
            bg_class_ind = self.num_classes
            # 0~self.num_classes-1 are FG, self.num_classes is BG
            pos_inds = (labels >= 0) & (labels < bg_class_ind)
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                if self.reg_decoded_bbox:
                    # When the regression loss (e.g. `IouLoss`,
                    # `GIouLoss`, `DIouLoss`) is applied directly on
                    # the decoded bounding boxes, it decodes the
                    # already encoded coordinates to absolute format.
                    bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
                if self.reg_class_agnostic:
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), 4)[pos_inds.type(torch.bool)]
                else:
                    pos_bbox_pred = bbox_pred.view(
                        bbox_pred.size(0), -1,
                        4)[pos_inds.type(torch.bool),
                           labels[pos_inds.type(torch.bool)]]
                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred,
                    bbox_targets[pos_inds.type(torch.bool)],
                    bbox_weights[pos_inds.type(torch.bool)],
                    avg_factor=bbox_targets.size(0),
                    reduction_override=reduction_override)
            else:
                losses['loss_bbox'] = bbox_pred[pos_inds].sum()
        return losses
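
    # Normalization note on the code above: the classification term averages
    # over samples with positive label_weights, while `loss_bbox` uses
    # avg_factor=bbox_targets.size(0), i.e. the regression loss is averaged
    # over all sampled RoIs rather than over positives only.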

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def get_bboxes(self,
                   rois,
                   cls_score,
                   bbox_pred,
                   img_shape,
                   scale_factor,
                   rescale=False,
                   cfg=None):
        """Transform network output for a batch into bbox predictions.

        Args:
            rois (Tensor): Boxes to be transformed. Has shape (num_boxes, 5),
                the last dimension 5 arranged as
                (batch_index, x1, y1, x2, y2).
            cls_score (Tensor): Box scores, has shape
                (num_boxes, num_classes + 1).
            bbox_pred (Tensor, optional): Box energies / deltas,
                has shape (num_boxes, num_classes * 4).
            img_shape (Sequence[int], optional): Maximum bounds for boxes,
                specifies (H, W, C) or (H, W).
            scale_factor (ndarray): Scale factor of the
                image arranged as (w_scale, h_scale, w_scale, h_scale).
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None.

        Returns:
            tuple[Tensor, Tensor]: The first tensor is `det_bboxes`, which
            has shape (num_boxes, 5), where the last dimension 5 represents
            (tl_x, tl_y, br_x, br_y, score). The second tensor is the
            labels with shape (num_boxes, ).
        """
        # some losses (e.g. Seesaw loss) may have a custom activation
        if self.custom_cls_channels:
            scores = self.loss_cls.get_activation(cls_score)
        else:
            scores = F.softmax(
                cls_score, dim=-1) if cls_score is not None else None
        # bbox_pred would be None in some detectors when with_reg is False,
        # e.g. Grid R-CNN.
        if bbox_pred is not None:
            bboxes = self.bbox_coder.decode(
                rois[..., 1:], bbox_pred, max_shape=img_shape)
        else:
            bboxes = rois[:, 1:].clone()
            if img_shape is not None:
                # Note: advanced indexing returns a copy, so an in-place
                # clamp_ on `bboxes[:, [0, 2]]` would silently be a no-op;
                # assign the clamped result back instead.
                bboxes[:, [0, 2]] = bboxes[:, [0, 2]].clamp(
                    min=0, max=img_shape[1])
                bboxes[:, [1, 3]] = bboxes[:, [1, 3]].clamp(
                    min=0, max=img_shape[0])

        if rescale and bboxes.size(0) > 0:
            scale_factor = bboxes.new_tensor(scale_factor)
            bboxes = (bboxes.view(bboxes.size(0), -1, 4) / scale_factor).view(
                bboxes.size()[0], -1)

        if cfg is None:
            return bboxes, scores
        else:
            det_bboxes, det_labels = multiclass_nms(bboxes, scores,
                                                    cfg.score_thr, cfg.nms,
                                                    cfg.max_per_img)
            return det_bboxes, det_labels
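
    # `get_bboxes` has two calling modes, mirroring the branch above: with
    # cfg=None it returns the raw (bboxes, scores) pair for the caller to
    # post-process (used e.g. when merging augmented test results), while a
    # test cfg triggers `multiclass_nms` and caps detections at
    # cfg.max_per_img.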

    @force_fp32(apply_to=('bbox_preds', ))
    def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
        """Refine bboxes during training.

        Args:
            rois (Tensor): Shape (n*bs, 5), where n is the number of images
                per GPU and bs is the number of sampled RoIs per image. The
                first column is the image id and the next 4 columns are
                x1, y1, x2, y2.
            labels (Tensor): Shape (n*bs, ).
            bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
            pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
                is a gt bbox.
            img_metas (list[dict]): Meta info of each image.

        Returns:
            list[Tensor]: Refined bboxes of each image in a mini-batch.

        Example:
            >>> # xdoctest: +REQUIRES(module:kwarray)
            >>> import kwarray
            >>> import numpy as np
            >>> from mmdet.core.bbox.demodata import random_boxes
            >>> self = BBoxHead(reg_class_agnostic=True)
            >>> n_roi = 2
            >>> n_img = 4
            >>> scale = 512
            >>> rng = np.random.RandomState(0)
            >>> img_metas = [{'img_shape': (scale, scale)}
            ...              for _ in range(n_img)]
            >>> # Create rois in the expected format
            >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
            >>> img_ids = torch.randint(0, n_img, (n_roi,))
            >>> img_ids = img_ids.float()
            >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
            >>> # Create other args
            >>> labels = torch.randint(0, 2, (n_roi,)).long()
            >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
            >>> # For each image, pretend random positive boxes are gts
            >>> # (use the builtin `int`; `np.int` is deprecated)
            >>> is_label_pos = (labels.numpy() > 0).astype(int)
            >>> lbl_per_img = kwarray.group_items(is_label_pos,
            ...                                   img_ids.numpy())
            >>> pos_per_img = [sum(lbl_per_img.get(gid, []))
            ...                for gid in range(n_img)]
            >>> pos_is_gts = [
            >>>     torch.randint(0, 2, (npos,)).byte().sort(
            >>>         descending=True)[0]
            >>>     for npos in pos_per_img
            >>> ]
            >>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds,
            >>>                                  pos_is_gts, img_metas)
            >>> print(bboxes_list)
        """
        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() <= len(img_metas)

        bboxes_list = []
        for i in range(len(img_metas)):
            inds = torch.nonzero(
                rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
            num_rois = inds.numel()

            bboxes_ = rois[inds, 1:]
            label_ = labels[inds]
            bbox_pred_ = bbox_preds[inds]
            img_meta_ = img_metas[i]
            pos_is_gts_ = pos_is_gts[i]

            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                           img_meta_)

            # filter gt bboxes
            pos_keep = 1 - pos_is_gts_
            keep_inds = pos_is_gts_.new_ones(num_rois)
            keep_inds[:len(pos_is_gts_)] = pos_keep

            bboxes_list.append(bboxes[keep_inds.type(torch.bool)])

        return bboxes_list

    @force_fp32(apply_to=('bbox_pred', ))
    def regress_by_class(self, rois, label, bbox_pred, img_meta):
        """Regress the bbox for the predicted class. Used in Cascade R-CNN.

        Args:
            rois (Tensor): Rois from `rpn_head` or last stage
                `bbox_head`, has shape (num_proposals, 4) or
                (num_proposals, 5).
            label (Tensor): Only used when `self.reg_class_agnostic`
                is False, has shape (num_proposals, ).
            bbox_pred (Tensor): Regression prediction of
                current stage `bbox_head`. When `self.reg_class_agnostic`
                is False, it has shape (n, num_classes * 4), otherwise
                it has shape (n, 4).
            img_meta (dict): Image meta info.

        Returns:
            Tensor: Regressed bboxes, the same shape as input rois.
        """
        assert rois.size(1) == 4 or rois.size(1) == 5, repr(rois.shape)

        if not self.reg_class_agnostic:
            label = label * 4
            inds = torch.stack((label, label + 1, label + 2, label + 3), 1)
            bbox_pred = torch.gather(bbox_pred, 1, inds)
        assert bbox_pred.size(1) == 4

        max_shape = img_meta['img_shape']

        if rois.size(1) == 4:
            new_rois = self.bbox_coder.decode(
                rois, bbox_pred, max_shape=max_shape)
        else:
            bboxes = self.bbox_coder.decode(
                rois[:, 1:], bbox_pred, max_shape=max_shape)
            new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)

        return new_rois
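
    # Worked example for the class-specific gather above: for a proposal
    # with label == 2 and bbox_pred of shape (n, num_classes * 4), `inds`
    # selects columns [8, 9, 10, 11], i.e. the four deltas predicted for
    # class 2.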

    def onnx_export(self,
                    rois,
                    cls_score,
                    bbox_pred,
                    img_shape,
                    cfg=None,
                    **kwargs):
        """Transform network output for a batch into bbox predictions.

        Args:
            rois (Tensor): Boxes to be transformed.
                Has shape (B, num_boxes, 5).
            cls_score (Tensor): Box scores, has shape
                (B, num_boxes, num_classes + 1); the extra channel
                represents the background class.
            bbox_pred (Tensor, optional): Box energies / deltas,
                has shape (B, num_boxes, num_classes * 4).
            img_shape (torch.Tensor): Shape of image.
            cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
            and class labels of shape [N, num_det].
        """
        assert rois.ndim == 3, 'Only support exporting two-stage models ' \
                               'to ONNX with the batch dimension.'

        if self.custom_cls_channels:
            scores = self.loss_cls.get_activation(cls_score)
        else:
            scores = F.softmax(
                cls_score, dim=-1) if cls_score is not None else None

        if bbox_pred is not None:
            bboxes = self.bbox_coder.decode(
                rois[..., 1:], bbox_pred, max_shape=img_shape)
        else:
            bboxes = rois[..., 1:].clone()
            if img_shape is not None:
                max_shape = bboxes.new_tensor(img_shape)[..., :2]
                min_xy = bboxes.new_tensor(0)
                max_xy = torch.cat(
                    [max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2)
                bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
                bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)

        # Replace multiclass_nms with ONNX::NonMaxSuppression in deployment
        from mmdet.core.export import add_dummy_nms_for_onnx
        max_output_boxes_per_class = cfg.nms.get('max_output_boxes_per_class',
                                                 cfg.max_per_img)
        iou_threshold = cfg.nms.get('iou_threshold', 0.5)
        score_threshold = cfg.score_thr
        nms_pre = cfg.get('deploy_nms_pre', -1)

        scores = scores[..., :self.num_classes]
        if self.reg_class_agnostic:
            return add_dummy_nms_for_onnx(
                bboxes,
                scores,
                max_output_boxes_per_class,
                iou_threshold,
                score_threshold,
                pre_top_k=nms_pre,
                after_top_k=cfg.max_per_img)
        else:
            batch_size = scores.shape[0]
            labels = torch.arange(
                self.num_classes, dtype=torch.long).to(scores.device)
            labels = labels.view(1, 1, -1).expand_as(scores)
            labels = labels.reshape(batch_size, -1)
            scores = scores.reshape(batch_size, -1)
            bboxes = bboxes.reshape(batch_size, -1, 4)

            max_size = torch.max(img_shape)
            # Offset bboxes of each class so that bboxes of different labels
            # do not overlap.
            offsets = (labels * max_size + 1).unsqueeze(2)
            bboxes_for_nms = bboxes + offsets

            batch_dets, labels = add_dummy_nms_for_onnx(
                bboxes_for_nms,
                scores.unsqueeze(2),
                max_output_boxes_per_class,
                iou_threshold,
                score_threshold,
                pre_top_k=nms_pre,
                after_top_k=cfg.max_per_img,
                labels=labels)
            # Offset the bboxes back after dummy nms.
            offsets = (labels * max_size + 1).unsqueeze(2)
            # Indexing + inplace operation fails with dynamic shape in ONNX
            # original style: batch_dets[..., :4] -= offsets
            bboxes, scores = batch_dets[..., 0:4], batch_dets[..., 4:5]
            bboxes -= offsets
            batch_dets = torch.cat([bboxes, scores], dim=2)
            return batch_dets, labels
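
The following is a minimal usage sketch of this head (an illustration, not
part of the file above; it assumes mmdet 2.x with mmcv-full installed and
uses with_avg_pool=True so raw RoI features can be fed to forward() without
manual flattening):

    import torch
    from mmdet.models.roi_heads.bbox_heads import BBoxHead

    head = BBoxHead(with_avg_pool=True, roi_feat_size=7, in_channels=256,
                    num_classes=80)

    roi_feats = torch.rand(8, 256, 7, 7)   # features of 8 sampled RoIs
    cls_score, bbox_pred = head(roi_feats)
    print(cls_score.shape)                  # torch.Size([8, 81])
    print(bbox_pred.shape)                  # torch.Size([8, 320])

    # cfg=None skips NMS and returns the raw (boxes, per-class scores).
    rois = torch.cat([torch.zeros(8, 1), torch.rand(8, 4) * 100], dim=1)
    bboxes, scores = head.get_bboxes(rois, cls_score, bbox_pred,
                                     img_shape=(224, 224, 3),
                                     scale_factor=1.0, cfg=None)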
