sabl_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, force_fp32

from mmdet.core import build_bbox_coder, multi_apply, multiclass_nms
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.losses import accuracy

@HEADS.register_module()
class SABLHead(BaseModule):
    """Side-Aware Boundary Localization (SABL) for RoI-Head.

    Side-Aware features are extracted by conv layers
    with an attention mechanism.
    Boundary Localization with Bucketing and Bucketing Guided Rescoring
    are implemented in BucketingBBoxCoder.

    Please refer to https://arxiv.org/abs/1912.04260 for more details.

    Args:
        cls_in_channels (int): Input channels of cls RoI feature. \
            Defaults to 256.
        reg_in_channels (int): Input channels of reg RoI feature. \
            Defaults to 256.
        roi_feat_size (int): Size of RoI features. Defaults to 7.
        reg_feat_up_ratio (int): Upsample ratio of reg features. \
            Defaults to 2.
        reg_pre_kernel (int): Kernel of 2D conv layers before \
            attention pooling. Defaults to 3.
        reg_post_kernel (int): Kernel of 1D conv layers after \
            attention pooling. Defaults to 3.
        reg_pre_num (int): Number of pre convs. Defaults to 2.
        reg_post_num (int): Number of post convs. Defaults to 1.
        num_classes (int): Number of classes in dataset. Defaults to 80.
        cls_out_channels (int): Hidden channels in cls fcs. Defaults to 1024.
        reg_offset_out_channels (int): Hidden and output channel \
            of reg offset branch. Defaults to 256.
        reg_cls_out_channels (int): Hidden and output channel \
            of reg cls branch. Defaults to 256.
        num_cls_fcs (int): Number of fcs for cls branch. Defaults to 1.
        num_reg_fcs (int): Number of fcs for reg branch. Defaults to 0.
        reg_class_agnostic (bool): Class agnostic regression or not. \
            Defaults to True.
        norm_cfg (dict): Config of norm layers. Defaults to None.
        bbox_coder (dict): Config of bbox coder. Defaults to \
            'BucketingBBoxCoder'.
        loss_cls (dict): Config of classification loss.
        loss_bbox_cls (dict): Config of classification loss for bbox branch.
        loss_bbox_reg (dict): Config of regression loss for bbox branch.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
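
    Example (illustrative usage, assuming the default config above and an
    installed mmdet registry; output shapes are deterministic, values are
    not)::

        >>> head = SABLHead(num_classes=80)
        >>> roi_feats = torch.rand(2, 256, 7, 7)  # (n, C, roi_h, roi_w)
        >>> cls_score, (bucket_cls_pred, bucket_offset_pred) = head(roi_feats)
        >>> cls_score.shape         # (n, num_classes + 1)
        torch.Size([2, 81])
        >>> bucket_cls_pred.shape   # (n, num_buckets * 2)
        torch.Size([2, 28])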
  50. """

    def __init__(self,
                 num_classes,
                 cls_in_channels=256,
                 reg_in_channels=256,
                 roi_feat_size=7,
                 reg_feat_up_ratio=2,
                 reg_pre_kernel=3,
                 reg_post_kernel=3,
                 reg_pre_num=2,
                 reg_post_num=1,
                 cls_out_channels=1024,
                 reg_offset_out_channels=256,
                 reg_cls_out_channels=256,
                 num_cls_fcs=1,
                 num_reg_fcs=0,
                 reg_class_agnostic=True,
                 norm_cfg=None,
                 bbox_coder=dict(
                     type='BucketingBBoxCoder',
                     num_buckets=14,
                     scale_factor=1.7),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_bbox_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox_reg=dict(
                     type='SmoothL1Loss', beta=0.1, loss_weight=1.0),
                 init_cfg=None):
        super(SABLHead, self).__init__(init_cfg)
        self.cls_in_channels = cls_in_channels
        self.reg_in_channels = reg_in_channels
        self.roi_feat_size = roi_feat_size
        self.reg_feat_up_ratio = int(reg_feat_up_ratio)
        self.num_buckets = bbox_coder['num_buckets']
        assert self.reg_feat_up_ratio // 2 >= 1
        self.up_reg_feat_size = roi_feat_size * self.reg_feat_up_ratio
        assert self.up_reg_feat_size == bbox_coder['num_buckets']
        self.reg_pre_kernel = reg_pre_kernel
        self.reg_post_kernel = reg_post_kernel
        self.reg_pre_num = reg_pre_num
        self.reg_post_num = reg_post_num
        self.num_classes = num_classes
        self.cls_out_channels = cls_out_channels
        self.reg_offset_out_channels = reg_offset_out_channels
        self.reg_cls_out_channels = reg_cls_out_channels
        self.num_cls_fcs = num_cls_fcs
        self.num_reg_fcs = num_reg_fcs
        self.reg_class_agnostic = reg_class_agnostic
        assert self.reg_class_agnostic
        self.norm_cfg = norm_cfg

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox_cls = build_loss(loss_bbox_cls)
        self.loss_bbox_reg = build_loss(loss_bbox_reg)

        self.cls_fcs = self._add_fc_branch(self.num_cls_fcs,
                                           self.cls_in_channels,
                                           self.roi_feat_size,
                                           self.cls_out_channels)

        self.side_num = int(np.ceil(self.num_buckets / 2))

        if self.reg_feat_up_ratio > 1:
            self.upsample_x = nn.ConvTranspose1d(
                reg_in_channels,
                reg_in_channels,
                self.reg_feat_up_ratio,
                stride=self.reg_feat_up_ratio)
            self.upsample_y = nn.ConvTranspose1d(
                reg_in_channels,
                reg_in_channels,
                self.reg_feat_up_ratio,
                stride=self.reg_feat_up_ratio)

        self.reg_pre_convs = nn.ModuleList()
        for i in range(self.reg_pre_num):
            reg_pre_conv = ConvModule(
                reg_in_channels,
                reg_in_channels,
                kernel_size=reg_pre_kernel,
                padding=reg_pre_kernel // 2,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'))
            self.reg_pre_convs.append(reg_pre_conv)

        self.reg_post_conv_xs = nn.ModuleList()
        for i in range(self.reg_post_num):
            reg_post_conv_x = ConvModule(
                reg_in_channels,
                reg_in_channels,
                kernel_size=(1, reg_post_kernel),
                padding=(0, reg_post_kernel // 2),
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'))
            self.reg_post_conv_xs.append(reg_post_conv_x)
        self.reg_post_conv_ys = nn.ModuleList()
        for i in range(self.reg_post_num):
            reg_post_conv_y = ConvModule(
                reg_in_channels,
                reg_in_channels,
                kernel_size=(reg_post_kernel, 1),
                padding=(reg_post_kernel // 2, 0),
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'))
            self.reg_post_conv_ys.append(reg_post_conv_y)

        self.reg_conv_att_x = nn.Conv2d(reg_in_channels, 1, 1)
        self.reg_conv_att_y = nn.Conv2d(reg_in_channels, 1, 1)

        self.fc_cls = nn.Linear(self.cls_out_channels, self.num_classes + 1)
        self.relu = nn.ReLU(inplace=True)

        self.reg_cls_fcs = self._add_fc_branch(self.num_reg_fcs,
                                               self.reg_in_channels, 1,
                                               self.reg_cls_out_channels)
        self.reg_offset_fcs = self._add_fc_branch(self.num_reg_fcs,
                                                  self.reg_in_channels, 1,
                                                  self.reg_offset_out_channels)
        self.fc_reg_cls = nn.Linear(self.reg_cls_out_channels, 1)
        self.fc_reg_offset = nn.Linear(self.reg_offset_out_channels, 1)

        if init_cfg is None:
            self.init_cfg = [
                dict(
                    type='Xavier',
                    layer='Linear',
                    distribution='uniform',
                    override=[
                        dict(type='Normal', name='reg_conv_att_x', std=0.01),
                        dict(type='Normal', name='reg_conv_att_y', std=0.01),
                        dict(type='Normal', name='fc_reg_cls', std=0.01),
                        dict(type='Normal', name='fc_cls', std=0.01),
                        dict(type='Normal', name='fc_reg_offset', std=0.001)
                    ])
            ]
            if self.reg_feat_up_ratio > 1:
                self.init_cfg += [
                    dict(
                        type='Kaiming',
                        distribution='normal',
                        override=[
                            dict(name='upsample_x'),
                            dict(name='upsample_y')
                        ])
                ]

    def _add_fc_branch(self, num_branch_fcs, in_channels, roi_feat_size,
                       fc_out_channels):
        """Build a branch of ``num_branch_fcs`` fc layers on flattened RoI
        features."""
        in_channels = in_channels * roi_feat_size * roi_feat_size
        branch_fcs = nn.ModuleList()
        for i in range(num_branch_fcs):
            fc_in_channels = (in_channels if i == 0 else fc_out_channels)
            branch_fcs.append(nn.Linear(fc_in_channels, fc_out_channels))
        return branch_fcs

    def cls_forward(self, cls_x):
        """Forward of the classification branch: flatten, hidden fcs, then
        the final classifier."""
        cls_x = cls_x.view(cls_x.size(0), -1)
        for fc in self.cls_fcs:
            cls_x = self.relu(fc(cls_x))
        cls_score = self.fc_cls(cls_x)
        return cls_score

    def attention_pool(self, reg_x):
        """Extract direction-specific features fx and fy with an attention
        mechanism."""
        reg_fx = reg_x
        reg_fy = reg_x
        reg_fx_att = self.reg_conv_att_x(reg_fx).sigmoid()
        reg_fy_att = self.reg_conv_att_y(reg_fy).sigmoid()
        reg_fx_att = reg_fx_att / reg_fx_att.sum(dim=2).unsqueeze(2)
        reg_fy_att = reg_fy_att / reg_fy_att.sum(dim=3).unsqueeze(3)
        reg_fx = (reg_fx * reg_fx_att).sum(dim=2)
        reg_fy = (reg_fy * reg_fy_att).sum(dim=3)
        return reg_fx, reg_fy
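
    # Shape sketch for attention_pool (illustrative, default config): for
    # reg_x of shape (n, 256, 7, 7), each attention map has shape
    # (n, 1, 7, 7) and is normalized over one spatial axis, so the weighted
    # sums collapse that axis and yield reg_fx, reg_fy of shape (n, 256, 7):
    # one feature vector per column (for x) and per row (for y).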

    def side_aware_feature_extractor(self, reg_x):
        """Refine and extract side-aware features without splitting them."""
        for reg_pre_conv in self.reg_pre_convs:
            reg_x = reg_pre_conv(reg_x)
        reg_fx, reg_fy = self.attention_pool(reg_x)

        if self.reg_post_num > 0:
            reg_fx = reg_fx.unsqueeze(2)
            reg_fy = reg_fy.unsqueeze(3)
            for i in range(self.reg_post_num):
                reg_fx = self.reg_post_conv_xs[i](reg_fx)
                reg_fy = self.reg_post_conv_ys[i](reg_fy)
            reg_fx = reg_fx.squeeze(2)
            reg_fy = reg_fy.squeeze(3)
        if self.reg_feat_up_ratio > 1:
            reg_fx = self.relu(self.upsample_x(reg_fx))
            reg_fy = self.relu(self.upsample_y(reg_fy))
        reg_fx = torch.transpose(reg_fx, 1, 2)
        reg_fy = torch.transpose(reg_fy, 1, 2)
        return reg_fx.contiguous(), reg_fy.contiguous()

    def reg_pred(self, x, offset_fcs, cls_fcs):
        """Predict bucketing estimation (cls_pred) and fine regression
        (offset_pred) with side-aware features."""
        x_offset = x.view(-1, self.reg_in_channels)
        x_cls = x.view(-1, self.reg_in_channels)

        for fc in offset_fcs:
            x_offset = self.relu(fc(x_offset))
        for fc in cls_fcs:
            x_cls = self.relu(fc(x_cls))
        offset_pred = self.fc_reg_offset(x_offset)
        cls_pred = self.fc_reg_cls(x_cls)

        offset_pred = offset_pred.view(x.size(0), -1)
        cls_pred = cls_pred.view(x.size(0), -1)

        return offset_pred, cls_pred

    def side_aware_split(self, feat):
        """Split side-aware features aligned with orders of bucketing
        targets."""
        l_end = int(np.ceil(self.up_reg_feat_size / 2))
        r_start = int(np.floor(self.up_reg_feat_size / 2))
        feat_fl = feat[:, :l_end]
        feat_fr = feat[:, r_start:].flip(dims=(1, ))
        feat_fl = feat_fl.contiguous()
        feat_fr = feat_fr.contiguous()
        feat = torch.cat([feat_fl, feat_fr], dim=-1)
        return feat
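
    # Worked example for side_aware_split (illustrative, default config):
    # with up_reg_feat_size = 14, l_end = r_start = 7, so a row
    # [f0, ..., f13] becomes [f0, ..., f6, f13, ..., f7]: the left half in
    # order, then the right half flipped, so both sides enumerate buckets
    # from the boundary inwards, matching the bucketing target order.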

    def bbox_pred_split(self, bbox_pred, num_proposals_per_img):
        """Split batch bbox prediction back to each image."""
        bucket_cls_preds, bucket_offset_preds = bbox_pred
        bucket_cls_preds = bucket_cls_preds.split(num_proposals_per_img, 0)
        bucket_offset_preds = bucket_offset_preds.split(
            num_proposals_per_img, 0)
        bbox_pred = tuple(zip(bucket_cls_preds, bucket_offset_preds))
        return bbox_pred
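
    # Worked example for bbox_pred_split (illustrative): with two images and
    # num_proposals_per_img = [512, 512], a batched (1024, 28) pair of
    # predictions becomes a tuple of per-image (bucket_cls, bucket_offset)
    # pairs, each tensor of shape (512, 28).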

    def reg_forward(self, reg_x):
        """Forward of the regression branch: side-aware features to bucket
        cls and offset predictions."""
        outs = self.side_aware_feature_extractor(reg_x)
        edge_offset_preds = []
        edge_cls_preds = []
        reg_fx = outs[0]
        reg_fy = outs[1]
        offset_pred_x, cls_pred_x = self.reg_pred(reg_fx, self.reg_offset_fcs,
                                                  self.reg_cls_fcs)
        offset_pred_y, cls_pred_y = self.reg_pred(reg_fy, self.reg_offset_fcs,
                                                  self.reg_cls_fcs)
        offset_pred_x = self.side_aware_split(offset_pred_x)
        offset_pred_y = self.side_aware_split(offset_pred_y)
        cls_pred_x = self.side_aware_split(cls_pred_x)
        cls_pred_y = self.side_aware_split(cls_pred_y)
        edge_offset_preds = torch.cat([offset_pred_x, offset_pred_y], dim=-1)
        edge_cls_preds = torch.cat([cls_pred_x, cls_pred_y], dim=-1)

        return (edge_cls_preds, edge_offset_preds)
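
    # Output shape sketch for reg_forward (illustrative, default config):
    # both returned tensors have shape (n, num_buckets * 2) = (n, 28),
    # ordered as [x-side buckets, y-side buckets] after side_aware_split.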

    def forward(self, x):
        bbox_pred = self.reg_forward(x)
        cls_score = self.cls_forward(x)

        return cls_score, bbox_pred

    def get_targets(self, sampling_results, gt_bboxes, gt_labels,
                    rcnn_train_cfg):
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        neg_proposals = [res.neg_bboxes for res in sampling_results]
        pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
        cls_reg_targets = self.bucket_target(pos_proposals, neg_proposals,
                                             pos_gt_bboxes, pos_gt_labels,
                                             rcnn_train_cfg)
        (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
         bucket_offset_targets, bucket_offset_weights) = cls_reg_targets
        return (labels, label_weights, (bucket_cls_targets,
                                        bucket_offset_targets),
                (bucket_cls_weights, bucket_offset_weights))

    def bucket_target(self,
                      pos_proposals_list,
                      neg_proposals_list,
                      pos_gt_bboxes_list,
                      pos_gt_labels_list,
                      rcnn_train_cfg,
                      concat=True):
        (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
         bucket_offset_targets, bucket_offset_weights) = multi_apply(
             self._bucket_target_single,
             pos_proposals_list,
             neg_proposals_list,
             pos_gt_bboxes_list,
             pos_gt_labels_list,
             cfg=rcnn_train_cfg)

        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bucket_cls_targets = torch.cat(bucket_cls_targets, 0)
            bucket_cls_weights = torch.cat(bucket_cls_weights, 0)
            bucket_offset_targets = torch.cat(bucket_offset_targets, 0)
            bucket_offset_weights = torch.cat(bucket_offset_weights, 0)
        return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
                bucket_offset_targets, bucket_offset_weights)

    def _bucket_target_single(self, pos_proposals, neg_proposals,
                              pos_gt_bboxes, pos_gt_labels, cfg):
        """Compute bucketing estimation targets and fine regression targets
        for a single image.

        Args:
            pos_proposals (Tensor): positive proposals of a single image,
                Shape (n_pos, 4)
            neg_proposals (Tensor): negative proposals of a single image,
                Shape (n_neg, 4).
            pos_gt_bboxes (Tensor): gt bboxes assigned to positive proposals
                of a single image, Shape (n_pos, 4).
            pos_gt_labels (Tensor): gt labels assigned to positive proposals
                of a single image, Shape (n_pos, ).
            cfg (dict): Config of calculating targets

        Returns:
            tuple:

                - labels (Tensor): Labels in a single image. \
                    Shape (n,).
                - label_weights (Tensor): Label weights in a single image. \
                    Shape (n,)
                - bucket_cls_targets (Tensor): Bucket cls targets in \
                    a single image. Shape (n, num_buckets*2).
                - bucket_cls_weights (Tensor): Bucket cls weights in \
                    a single image. Shape (n, num_buckets*2).
                - bucket_offset_targets (Tensor): Bucket offset targets \
                    in a single image. Shape (n, num_buckets*2).
                - bucket_offset_weights (Tensor): Bucket offset weights \
                    in a single image. Shape (n, num_buckets*2).
        """
        num_pos = pos_proposals.size(0)
        num_neg = neg_proposals.size(0)
        num_samples = num_pos + num_neg
        labels = pos_gt_bboxes.new_full((num_samples, ),
                                        self.num_classes,
                                        dtype=torch.long)
        label_weights = pos_proposals.new_zeros(num_samples)
        bucket_cls_targets = pos_proposals.new_zeros(num_samples,
                                                     4 * self.side_num)
        bucket_cls_weights = pos_proposals.new_zeros(num_samples,
                                                     4 * self.side_num)
        bucket_offset_targets = pos_proposals.new_zeros(
            num_samples, 4 * self.side_num)
        bucket_offset_weights = pos_proposals.new_zeros(
            num_samples, 4 * self.side_num)
        if num_pos > 0:
            labels[:num_pos] = pos_gt_labels
            label_weights[:num_pos] = 1.0
            (pos_bucket_offset_targets, pos_bucket_offset_weights,
             pos_bucket_cls_targets,
             pos_bucket_cls_weights) = self.bbox_coder.encode(
                 pos_proposals, pos_gt_bboxes)
            bucket_cls_targets[:num_pos, :] = pos_bucket_cls_targets
            bucket_cls_weights[:num_pos, :] = pos_bucket_cls_weights
            bucket_offset_targets[:num_pos, :] = pos_bucket_offset_targets
            bucket_offset_weights[:num_pos, :] = pos_bucket_offset_weights
        if num_neg > 0:
            label_weights[-num_neg:] = 1.0
        return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
                bucket_offset_targets, bucket_offset_weights)

    def loss(self,
             cls_score,
             bbox_pred,
             rois,
             labels,
             label_weights,
             bbox_targets,
             bbox_weights,
             reduction_override=None):
        losses = dict()
        if cls_score is not None:
            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
            losses['loss_cls'] = self.loss_cls(
                cls_score,
                labels,
                label_weights,
                avg_factor=avg_factor,
                reduction_override=reduction_override)
            losses['acc'] = accuracy(cls_score, labels)

        if bbox_pred is not None:
            bucket_cls_preds, bucket_offset_preds = bbox_pred
            bucket_cls_targets, bucket_offset_targets = bbox_targets
            bucket_cls_weights, bucket_offset_weights = bbox_weights
            # edge cls
            bucket_cls_preds = bucket_cls_preds.view(-1, self.side_num)
            bucket_cls_targets = bucket_cls_targets.view(-1, self.side_num)
            bucket_cls_weights = bucket_cls_weights.view(-1, self.side_num)
            losses['loss_bbox_cls'] = self.loss_bbox_cls(
                bucket_cls_preds,
                bucket_cls_targets,
                bucket_cls_weights,
                avg_factor=bucket_cls_targets.size(0),
                reduction_override=reduction_override)

            losses['loss_bbox_reg'] = self.loss_bbox_reg(
                bucket_offset_preds,
                bucket_offset_targets,
                bucket_offset_weights,
                avg_factor=bucket_offset_targets.size(0),
                reduction_override=reduction_override)
        return losses

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def get_bboxes(self,
                   rois,
                   cls_score,
                   bbox_pred,
                   img_shape,
                   scale_factor,
                   rescale=False,
                   cfg=None):
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        if bbox_pred is not None:
            bboxes, confidences = self.bbox_coder.decode(
                rois[:, 1:], bbox_pred, img_shape)
        else:
            bboxes = rois[:, 1:].clone()
            confidences = None
            if img_shape is not None:
                bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1] - 1)
                bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1)

        if rescale and bboxes.size(0) > 0:
            if isinstance(scale_factor, float):
                bboxes /= scale_factor
            else:
                bboxes /= torch.from_numpy(scale_factor).to(bboxes.device)

        if cfg is None:
            return bboxes, scores
        else:
            det_bboxes, det_labels = multiclass_nms(
                bboxes,
                scores,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=confidences)

            return det_bboxes, det_labels
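
    # Note (illustrative): with a test cfg such as
    # dict(score_thr=0.05, nms=dict(type='nms', iou_threshold=0.5),
    #      max_per_img=100), get_bboxes runs multiclass NMS and passes the
    # bucketing confidences as score_factors (Bucketing Guided Rescoring);
    # with cfg=None it returns the raw boxes and softmax scores instead.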

    @force_fp32(apply_to=('bbox_preds', ))
    def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
        """Refine bboxes during training.

        Args:
            rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
                and bs is the sampled RoIs per image.
            labels (Tensor): Shape (n*bs, ).
            bbox_preds (list[Tensor]): Shape [(n*bs, num_buckets*2), \
                (n*bs, num_buckets*2)].
            pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
                is a gt bbox.
            img_metas (list[dict]): Meta info of each image.

        Returns:
            list[Tensor]: Refined bboxes of each image in a mini-batch.
        """
        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() == len(img_metas)

        bboxes_list = []
        for i in range(len(img_metas)):
            inds = torch.nonzero(
                rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
            num_rois = inds.numel()

            bboxes_ = rois[inds, 1:]
            label_ = labels[inds]
            edge_cls_preds, edge_offset_preds = bbox_preds
            edge_cls_preds_ = edge_cls_preds[inds]
            edge_offset_preds_ = edge_offset_preds[inds]
            bbox_pred_ = [edge_cls_preds_, edge_offset_preds_]
            img_meta_ = img_metas[i]
            pos_is_gts_ = pos_is_gts[i]

            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                           img_meta_)
            # filter gt bboxes
            pos_keep = 1 - pos_is_gts_
            keep_inds = pos_is_gts_.new_ones(num_rois)
            keep_inds[:len(pos_is_gts_)] = pos_keep

            bboxes_list.append(bboxes[keep_inds.type(torch.bool)])

        return bboxes_list

    @force_fp32(apply_to=('bbox_pred', ))
    def regress_by_class(self, rois, label, bbox_pred, img_meta):
        """Regress the bbox for the predicted class. Used in Cascade R-CNN.

        Args:
            rois (Tensor): shape (n, 4) or (n, 5)
            label (Tensor): shape (n, )
            bbox_pred (list[Tensor]): shape [(n, num_buckets *2), \
                (n, num_buckets *2)]
            img_meta (dict): Image meta info.

        Returns:
            Tensor: Regressed bboxes, the same shape as input rois.
        """
        assert rois.size(1) == 4 or rois.size(1) == 5

        if rois.size(1) == 4:
            new_rois, _ = self.bbox_coder.decode(rois, bbox_pred,
                                                 img_meta['img_shape'])
        else:
            bboxes, _ = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
                                               img_meta['img_shape'])
            new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)

        return new_rois
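
# Training-flow sketch (illustrative, not part of the upstream file): inside
# an mmdet RoI head, this head is typically driven as
#   cls_score, bbox_pred = self.forward(roi_feats)
#   targets = self.get_targets(sampling_results, gt_bboxes, gt_labels, cfg)
#   losses = self.loss(cls_score, bbox_pred, rois, *targets)
# where `sampling_results`, `rois` and `cfg` come from the surrounding
# RoI head, as in mmdet's standard bbox-head training loop.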
