
cascade_rpn_head.py
# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
import copy
import warnings

import torch
import torch.nn as nn
from mmcv import ConfigDict
from mmcv.ops import DeformConv2d, batched_nms
from mmcv.runner import BaseModule, ModuleList

from mmdet.core import (RegionAssigner, build_assigner, build_sampler,
                        images_to_levels, multi_apply)
from mmdet.core.utils import select_single_mlvl
from ..builder import HEADS, build_head
from .base_dense_head import BaseDenseHead
from .rpn_head import RPNHead

class AdaptiveConv(BaseModule):
    """AdaptiveConv used to adapt the sampling location with the anchors.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the conv kernel. Default: 3.
        stride (int or tuple, optional): Stride of the convolution.
            Default: 1.
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 1.
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 3.
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1.
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: False.
        type (str, optional): Type of adaptive conv, can be either 'offset'
            (arbitrary anchors) or 'dilation' (uniform anchors).
            Default: 'dilation'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=3,
                 groups=1,
                 bias=False,
                 type='dilation',
                 init_cfg=dict(
                     type='Normal', std=0.01, override=dict(name='conv'))):
        super(AdaptiveConv, self).__init__(init_cfg)
        assert type in ['offset', 'dilation']
        self.adapt_type = type

        assert kernel_size == 3, 'Adaptive conv only supports kernel size 3'
        if self.adapt_type == 'offset':
            assert stride == 1 and padding == 1 and groups == 1, \
                'Adaptive conv offset mode only supports padding: 1, ' \
                'stride: 1, groups: 1'
            self.conv = DeformConv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                stride=stride,
                groups=groups,
                bias=bias)
        else:
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=dilation,
                dilation=dilation)
    def forward(self, x, offset):
        """Forward function."""
        if self.adapt_type == 'offset':
            N, _, H, W = x.shape
            assert offset is not None
            assert H * W == offset.shape[1]
            # reshape [N, NA, 18] to (N, 18, H, W)
            offset = offset.permute(0, 2, 1).reshape(N, -1, H, W)
            offset = offset.contiguous()
            x = self.conv(x, offset)
        else:
            assert offset is None
            x = self.conv(x)
        return x
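
# Illustrative usage sketch (not part of the upstream file): in 'dilation'
# mode the module is a plain dilated 3x3 conv and takes no offset; in
# 'offset' mode it expects a per-location offset tensor of shape
# [N, H*W, 18] (2 coordinates for each of the 9 kernel samples), as produced
# by `StageCascadeRPNHead.anchor_offset` below. Shapes here are assumptions
# chosen for the example.
#
# >>> feat = torch.rand(2, 256, 32, 32)
# >>> conv = AdaptiveConv(256, 256, type='dilation')
# >>> out = conv(feat, offset=None)           # (2, 256, 32, 32)
# >>> conv = AdaptiveConv(256, 256, type='offset')
# >>> offset = torch.zeros(2, 32 * 32, 18)    # zero offset == regular conv
# >>> out = conv(feat, offset)                # (2, 256, 32, 32)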

@HEADS.register_module()
class StageCascadeRPNHead(RPNHead):
    """Stage of CascadeRPNHead.

    Args:
        in_channels (int): Number of channels in the input feature map.
        anchor_generator (dict): anchor generator config.
        adapt_cfg (dict): adaptation config.
        bridged_feature (bool, optional): whether to update the rpn feature
            passed to the next stage. Default: False.
        with_cls (bool, optional): whether to use the classification branch.
            Default: True.
        sampling (bool, optional): whether to use sampling. Default: True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """
    def __init__(self,
                 in_channels,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     scales=[8],
                     ratios=[1.0],
                     strides=[4, 8, 16, 32, 64]),
                 adapt_cfg=dict(type='dilation', dilation=3),
                 bridged_feature=False,
                 with_cls=True,
                 sampling=True,
                 init_cfg=None,
                 **kwargs):
        self.with_cls = with_cls
        self.anchor_strides = anchor_generator['strides']
        self.anchor_scales = anchor_generator['scales']
        self.bridged_feature = bridged_feature
        self.adapt_cfg = adapt_cfg
        super(StageCascadeRPNHead, self).__init__(
            in_channels,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

        # override sampling and sampler
        self.sampling = sampling
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # use PseudoSampler when sampling is False
            if self.sampling and hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        if init_cfg is None:
            self.init_cfg = dict(
                type='Normal', std=0.01, override=[dict(name='rpn_reg')])
            if self.with_cls:
                self.init_cfg['override'].append(dict(name='rpn_cls'))
    def _init_layers(self):
        """Init layers of a CascadeRPN stage."""
        self.rpn_conv = AdaptiveConv(self.in_channels, self.feat_channels,
                                     **self.adapt_cfg)
        if self.with_cls:
            self.rpn_cls = nn.Conv2d(self.feat_channels,
                                     self.num_anchors * self.cls_out_channels,
                                     1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
        self.relu = nn.ReLU(inplace=True)
    def forward_single(self, x, offset):
        """Forward function of single scale."""
        bridged_x = x
        x = self.relu(self.rpn_conv(x, offset))
        if self.bridged_feature:
            bridged_x = x  # update feature
        cls_score = self.rpn_cls(x) if self.with_cls else None
        bbox_pred = self.rpn_reg(x)
        return bridged_x, cls_score, bbox_pred

    def forward(self, feats, offset_list=None):
        """Forward function."""
        if offset_list is None:
            offset_list = [None for _ in range(len(feats))]
        return multi_apply(self.forward_single, feats, offset_list)
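
    # Note (added comment, an interpretation of the code above): `multi_apply`
    # maps `forward_single` over the FPN levels and transposes the outputs, so
    # `forward` returns three per-level lists: (bridged feats, cls_scores,
    # bbox_preds). The first list feeds the next cascade stage; it contains
    # the adapted feature only when `bridged_feature=True`, otherwise the
    # input feature is passed through unchanged.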
    def _region_targets_single(self,
                               anchors,
                               valid_flags,
                               gt_bboxes,
                               gt_bboxes_ignore,
                               gt_labels,
                               img_meta,
                               featmap_sizes,
                               label_channels=1):
        """Get anchor targets based on region for single level."""
        assign_result = self.assigner.assign(
            anchors,
            valid_flags,
            gt_bboxes,
            img_meta,
            featmap_sizes,
            self.anchor_scales[0],
            self.anchor_strides,
            gt_bboxes_ignore=gt_bboxes_ignore,
            gt_labels=None,
            allowed_border=self.train_cfg.allowed_border)
        flat_anchors = torch.cat(anchors)
        sampling_result = self.sampler.sample(assign_result, flat_anchors,
                                              gt_bboxes)

        num_anchors = flat_anchors.shape[0]
        bbox_targets = torch.zeros_like(flat_anchors)
        bbox_weights = torch.zeros_like(flat_anchors)
        labels = flat_anchors.new_zeros(num_anchors, dtype=torch.long)
        label_weights = flat_anchors.new_zeros(num_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
            else:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                labels[pos_inds] = 1
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)
    def region_targets(self,
                       anchor_list,
                       valid_flag_list,
                       gt_bboxes_list,
                       img_metas,
                       featmap_sizes,
                       gt_bboxes_ignore_list=None,
                       gt_labels_list=None,
                       label_channels=1,
                       unmap_outputs=True):
        """See :func:`StageCascadeRPNHead.get_targets`."""
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list) = multi_apply(
             self._region_targets_single,
             anchor_list,
             valid_flag_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             img_metas,
             featmap_sizes=featmap_sizes,
             label_channels=label_channels)
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)
    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes,
                    img_metas,
                    featmap_sizes,
                    gt_bboxes_ignore=None,
                    label_channels=1):
        """Compute regression and classification targets for anchors.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            gt_bboxes (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            featmap_sizes (list[Tensor]): Feature map size of each level.
            gt_bboxes_ignore (list[Tensor]): Ignore bboxes of each image.
            label_channels (int): Channel of label.

        Returns:
            cls_reg_targets (tuple)
        """
        if isinstance(self.assigner, RegionAssigner):
            cls_reg_targets = self.region_targets(
                anchor_list,
                valid_flag_list,
                gt_bboxes,
                img_metas,
                featmap_sizes,
                gt_bboxes_ignore_list=gt_bboxes_ignore,
                label_channels=label_channels)
        else:
            cls_reg_targets = super(StageCascadeRPNHead, self).get_targets(
                anchor_list,
                valid_flag_list,
                gt_bboxes,
                img_metas,
                gt_bboxes_ignore_list=gt_bboxes_ignore,
                label_channels=label_channels)
        return cls_reg_targets
    def anchor_offset(self, anchor_list, anchor_strides, featmap_sizes):
        """Get offset for deformable conv based on anchor shape.
        NOTE: currently support deformable kernel_size=3 and dilation=1

        Args:
            anchor_list (list[list[tensor]]): [NI, NLVL, NA, 4] list of
                multi-level anchors
            anchor_strides (list[int]): anchor stride of each level

        Returns:
            offset_list (list[tensor]): [NLVL, NI, NA, 2*ks**2] offset of
                DeformConv kernel.
        """

        def _shape_offset(anchors, stride, ks=3, dilation=1):
            # currently support kernel_size=3 and dilation=1
            assert ks == 3 and dilation == 1
            pad = (ks - 1) // 2
            idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device)
            yy, xx = torch.meshgrid(idx, idx)  # return order matters
            xx = xx.reshape(-1)
            yy = yy.reshape(-1)
            w = (anchors[:, 2] - anchors[:, 0]) / stride
            h = (anchors[:, 3] - anchors[:, 1]) / stride
            w = w / (ks - 1) - dilation
            h = h / (ks - 1) - dilation
            offset_x = w[:, None] * xx  # (NA, ks**2)
            offset_y = h[:, None] * yy  # (NA, ks**2)
            return offset_x, offset_y

        def _ctr_offset(anchors, stride, featmap_size):
            feat_h, feat_w = featmap_size
            assert len(anchors) == feat_h * feat_w

            x = (anchors[:, 0] + anchors[:, 2]) * 0.5
            y = (anchors[:, 1] + anchors[:, 3]) * 0.5
            # compute centers on feature map
            x = x / stride
            y = y / stride
            # compute predefine centers
            xx = torch.arange(0, feat_w, device=anchors.device)
            yy = torch.arange(0, feat_h, device=anchors.device)
            yy, xx = torch.meshgrid(yy, xx)
            xx = xx.reshape(-1).type_as(x)
            yy = yy.reshape(-1).type_as(y)

            offset_x = x - xx  # (NA, )
            offset_y = y - yy  # (NA, )
            return offset_x, offset_y

        num_imgs = len(anchor_list)
        num_lvls = len(anchor_list[0])
        dtype = anchor_list[0][0].dtype
        device = anchor_list[0][0].device
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        offset_list = []
        for i in range(num_imgs):
            mlvl_offset = []
            for lvl in range(num_lvls):
                c_offset_x, c_offset_y = _ctr_offset(anchor_list[i][lvl],
                                                     anchor_strides[lvl],
                                                     featmap_sizes[lvl])
                s_offset_x, s_offset_y = _shape_offset(anchor_list[i][lvl],
                                                       anchor_strides[lvl])

                # offset = ctr_offset + shape_offset
                offset_x = s_offset_x + c_offset_x[:, None]
                offset_y = s_offset_y + c_offset_y[:, None]

                # offset order (y0, x0, y1, x1, ..., y8, x8)
                offset = torch.stack([offset_y, offset_x], dim=-1)
                offset = offset.reshape(offset.size(0), -1)  # [NA, 2*ks**2]
                mlvl_offset.append(offset)
            offset_list.append(torch.cat(mlvl_offset))  # [totalNA, 2*ks**2]
        offset_list = images_to_levels(offset_list, num_level_anchors)
        return offset_list
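
    # Worked example (illustrative, numbers are assumptions): for an anchor
    # of 32x32 px on a stride-8 level, the anchor spans w = h = 32 / 8 = 4
    # feature cells. `_shape_offset` scales the unit 3x3 grid by
    # w / (ks - 1) - dilation = 4 / 2 - 1 = 1, i.e. each kernel corner moves
    # 1 extra cell outward so the 3x3 samples cover the whole anchor instead
    # of a fixed 3x3 window. `_ctr_offset` then adds the shift between the
    # (possibly refined) anchor center and the pixel at which the DeformConv
    # kernel is anchored.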
    def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples):
        """Loss function on single scale."""
        # classification loss
        if self.with_cls:
            labels = labels.reshape(-1)
            label_weights = label_weights.reshape(-1)
            cls_score = cls_score.permute(0, 2, 3,
                                          1).reshape(-1, self.cls_out_channels)
            loss_cls = self.loss_cls(
                cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.reshape(-1, 4)
        bbox_weights = bbox_weights.reshape(-1, 4)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
            # is applied directly on the decoded bounding boxes, it
            # decodes the already encoded coordinates to absolute format.
            anchors = anchors.reshape(-1, 4)
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_reg = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            avg_factor=num_total_samples)
        if self.with_cls:
            return loss_cls, loss_reg
        return None, loss_reg
    def loss(self,
             anchor_list,
             valid_flag_list,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            cls_scores (list[Tensor]): Box scores for each scale level,
                each has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss. Default: None

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            featmap_sizes,
            gt_bboxes_ignore=gt_bboxes_ignore,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        if self.sampling:
            num_total_samples = num_total_pos + num_total_neg
        else:
            # 200 is hard-coded average factor,
            # which follows guided anchoring.
            num_total_samples = sum([label.numel()
                                     for label in labels_list]) / 200.0

        # change per image, per level anchor_list to per_level, per_image
        mlvl_anchor_list = list(zip(*anchor_list))
        # concat mlvl_anchor_list
        mlvl_anchor_list = [
            torch.cat(anchors, dim=0) for anchors in mlvl_anchor_list
        ]

        losses = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            mlvl_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples)
        if self.with_cls:
            return dict(loss_rpn_cls=losses[0], loss_rpn_reg=losses[1])
        return dict(loss_rpn_reg=losses[1])
    def get_bboxes(self,
                   anchor_list,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg,
                   rescale=False):
        """Get proposal predictions.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            img_metas (list[dict], Optional): Image meta info. Default None.
            cfg (mmcv.Config, Optional): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            list[Tensor]: Proposals of each image, each with shape (n, 5),
                where the first 4 columns are bounding box positions
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                between 0 and 1.
        """
        assert len(cls_scores) == len(bbox_preds)

        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = select_single_mlvl(cls_scores, img_id)
            bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
                                                anchor_list[img_id], img_shape,
                                                scale_factor, cfg, rescale)
            result_list.append(proposals)
        return result_list
    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has
                shape (num_anchors * 4, H, W).
            mlvl_anchors (list[Tensor]): Box reference from all scale
                levels of a single image, each item has shape
                (num_total_anchors, 4).
            img_shape (tuple[int]): Shape of the input image,
                (height, width, 3).
            scale_factor (ndarray): Scale factor of the image arranged as
                (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
                are bounding box positions (tl_x, tl_y, br_x, br_y) and the
                5-th column is a score between 0 and 1.
        """
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        # bboxes from different level should be independent during NMS,
        # level_ids are used as labels for batched NMS to separate them
        level_ids = []
        mlvl_scores = []
        mlvl_bbox_preds = []
        mlvl_valid_anchors = []
        nms_pre = cfg.get('nms_pre', -1)
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                # We set FG labels to [0, num_class-1] and BG label to
                # num_class in RPN head since mmdet v2.5, which is unified to
                # be consistent with other head since mmdet v2.0. In mmdet
                # v2.0 to v2.4 we keep BG label as 0 and FG label as 1 in
                # rpn head.
                scores = rpn_cls_score.softmax(dim=1)[:, 0]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            anchors = mlvl_anchors[idx]

            if 0 < nms_pre < scores.shape[0]:
                # sort is faster than topk
                # _, topk_inds = scores.topk(cfg.nms_pre)
                ranked_scores, rank_inds = scores.sort(descending=True)
                topk_inds = rank_inds[:nms_pre]
                scores = ranked_scores[:nms_pre]
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]

            mlvl_scores.append(scores)
            mlvl_bbox_preds.append(rpn_bbox_pred)
            mlvl_valid_anchors.append(anchors)
            level_ids.append(
                scores.new_full((scores.size(0), ), idx, dtype=torch.long))

        scores = torch.cat(mlvl_scores)
        anchors = torch.cat(mlvl_valid_anchors)
        rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
        proposals = self.bbox_coder.decode(
            anchors, rpn_bbox_pred, max_shape=img_shape)
        ids = torch.cat(level_ids)

        if cfg.min_bbox_size >= 0:
            w = proposals[:, 2] - proposals[:, 0]
            h = proposals[:, 3] - proposals[:, 1]
            valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
            if not valid_mask.all():
                proposals = proposals[valid_mask]
                scores = scores[valid_mask]
                ids = ids[valid_mask]

        # deprecate arguments warning
        if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
            warnings.warn(
                'In rpn_proposal or test_cfg, '
                'nms_thr has been moved to a dict named nms as '
                'iou_threshold, max_num has been renamed as max_per_img, '
                'name of original arguments and the way to specify '
                'iou_threshold of NMS will be deprecated.')
        if 'nms' not in cfg:
            cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
        if 'max_num' in cfg:
            if 'max_per_img' in cfg:
                assert cfg.max_num == cfg.max_per_img, f'You set max_num ' \
                    f'and max_per_img at the same time, but get ' \
                    f'{cfg.max_num} and {cfg.max_per_img} respectively. ' \
                    'Please delete max_num which will be deprecated.'
            else:
                cfg.max_per_img = cfg.max_num
        if 'nms_thr' in cfg:
            assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
                f'iou_threshold in nms and nms_thr at the same time, ' \
                f'but get {cfg.nms.iou_threshold} and {cfg.nms_thr} ' \
                'respectively. Please delete the nms_thr ' \
                'which will be deprecated.'

        if proposals.numel() > 0:
            dets, _ = batched_nms(proposals, scores, ids, cfg.nms)
        else:
            return proposals.new_zeros(0, 5)

        return dets[:cfg.max_per_img]
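
    # Note (added comment): `batched_nms` offsets boxes by their `ids` value
    # before running NMS, so proposals from different FPN levels never
    # suppress each other; the final `dets` are still ranked by score across
    # all levels before the `max_per_img` cut.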
    def refine_bboxes(self, anchor_list, bbox_preds, img_metas):
        """Refine bboxes through stages."""
        num_levels = len(bbox_preds)
        new_anchor_list = []
        for img_id in range(len(img_metas)):
            mlvl_anchors = []
            for i in range(num_levels):
                bbox_pred = bbox_preds[i][img_id].detach()
                bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
                img_shape = img_metas[img_id]['img_shape']
                bboxes = self.bbox_coder.decode(anchor_list[img_id][i],
                                                bbox_pred, img_shape)
                mlvl_anchors.append(bboxes)
            new_anchor_list.append(mlvl_anchors)
        return new_anchor_list
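
    # Refinement sketch (illustrative interpretation): decoding the stage-i
    # regression deltas on the current anchors yields one refined box per
    # location, which replaces that location's anchor. The next stage
    # therefore sees a single better-fitting anchor per position instead of
    # the fixed initial grid, which is the core idea of the cascade.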

@HEADS.register_module()
class CascadeRPNHead(BaseDenseHead):
    """The CascadeRPNHead will predict more accurate region proposals, which
    are required for two-stage detectors (such as Fast/Faster R-CNN).
    CascadeRPN consists of a sequence of RPNStage modules that progressively
    improve the accuracy of the detected proposals.

    More details can be found in ``https://arxiv.org/abs/1909.06720``.

    Args:
        num_stages (int): number of CascadeRPN stages.
        stages (list[dict]): list of configs to build the stages.
        train_cfg (list[dict]): list of configs at training time for each
            stage.
        test_cfg (dict): config at testing time.
    """

    def __init__(self, num_stages, stages, train_cfg, test_cfg, init_cfg=None):
        super(CascadeRPNHead, self).__init__(init_cfg)
        assert num_stages == len(stages)
        self.num_stages = num_stages
        # Be careful! Pretrained weights cannot be loaded when using
        # nn.ModuleList
        self.stages = ModuleList()
        for i in range(len(stages)):
            train_cfg_i = train_cfg[i] if train_cfg is not None else None
            stages[i].update(train_cfg=train_cfg_i)
            stages[i].update(test_cfg=test_cfg)
            self.stages.append(build_head(stages[i]))
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
    def loss(self):
        """loss() is implemented in StageCascadeRPNHead."""
        pass

    def get_bboxes(self):
        """get_bboxes() is implemented in StageCascadeRPNHead."""
        pass
    def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None):
        """Forward train function."""
        assert gt_labels is None, 'RPN does not require gt_labels'

        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, valid_flag_list = self.stages[0].get_anchors(
            featmap_sizes, img_metas, device=device)

        losses = dict()
        for i in range(self.num_stages):
            stage = self.stages[i]

            if stage.adapt_cfg['type'] == 'offset':
                offset_list = stage.anchor_offset(anchor_list,
                                                  stage.anchor_strides,
                                                  featmap_sizes)
            else:
                offset_list = None
            x, cls_score, bbox_pred = stage(x, offset_list)
            rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score,
                               bbox_pred, gt_bboxes, img_metas)
            stage_loss = stage.loss(*rpn_loss_inputs)
            for name, value in stage_loss.items():
                losses['s{}.{}'.format(i, name)] = value

            # refine boxes
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  img_metas)
        if proposal_cfg is None:
            return losses
        else:
            proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
                                                       bbox_pred, img_metas,
                                                       self.test_cfg)
            return losses, proposal_list
    def simple_test_rpn(self, x, img_metas):
        """Simple forward test function."""
        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        device = x[0].device
        anchor_list, _ = self.stages[0].get_anchors(
            featmap_sizes, img_metas, device=device)

        for i in range(self.num_stages):
            stage = self.stages[i]
            if stage.adapt_cfg['type'] == 'offset':
                offset_list = stage.anchor_offset(anchor_list,
                                                  stage.anchor_strides,
                                                  featmap_sizes)
            else:
                offset_list = None
            x, cls_score, bbox_pred = stage(x, offset_list)
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  img_metas)

        proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
                                                   bbox_pred, img_metas,
                                                   self.test_cfg)
        return proposal_list

    def aug_test_rpn(self, x, img_metas):
        """Augmented forward test function."""
        raise NotImplementedError(
            'CascadeRPNHead does not support test-time augmentation')
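
# Minimal config sketch (an assumption based on the two-stage setup described
# in the Cascade RPN paper, not copied from this repository): stage 1 is
# class-agnostic ('dilation' adaptation, regression only, anchor-free-style
# single-ratio anchors), stage 2 adapts to the refined anchors via 'offset'
# and adds the classification branch. Channel sizes and omitted keys
# (bbox_coder, losses, train/test cfg) are placeholders.
#
# rpn_head = dict(
#     type='CascadeRPNHead',
#     num_stages=2,
#     stages=[
#         dict(
#             type='StageCascadeRPNHead',
#             in_channels=256,
#             feat_channels=256,
#             anchor_generator=dict(
#                 type='AnchorGenerator',
#                 scales=[8],
#                 ratios=[1.0],
#                 strides=[4, 8, 16, 32, 64]),
#             adapt_cfg=dict(type='dilation', dilation=3),
#             bridged_feature=True,
#             sampling=False,
#             with_cls=False),
#         dict(
#             type='StageCascadeRPNHead',
#             in_channels=256,
#             feat_channels=256,
#             adapt_cfg=dict(type='offset'),
#             bridged_feature=False,
#             sampling=True,
#             with_cls=True),
#     ])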
