guided_anchor_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmcv.ops import DeformConv2d, MaskedConv2d
from mmcv.runner import BaseModule, force_fp32

from mmdet.core import (anchor_inside_flags, build_assigner, build_bbox_coder,
                        build_prior_generator, build_sampler, calc_region,
                        images_to_levels, multi_apply, multiclass_nms, unmap)
from ..builder import HEADS, build_loss
from .anchor_head import AnchorHead


class FeatureAdaption(BaseModule):
    """Feature Adaption Module.

    Feature Adaption Module is implemented based on DCN v1.
    It uses the anchor shape prediction rather than the feature map to
    predict the offsets of the deform conv layer.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels in the output feature map.
        kernel_size (int): Deformable conv kernel size.
        deform_groups (int): Deformable conv group size.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 deform_groups=4,
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.1,
                     override=dict(
                         type='Normal', name='conv_adaption', std=0.01))):
        super(FeatureAdaption, self).__init__(init_cfg)
        offset_channels = kernel_size * kernel_size * 2
        self.conv_offset = nn.Conv2d(
            2, deform_groups * offset_channels, 1, bias=False)
        self.conv_adaption = DeformConv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
            deform_groups=deform_groups)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, shape):
        offset = self.conv_offset(shape.detach())
        x = self.relu(self.conv_adaption(x, offset))
        return x
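
# Note (illustrative, not part of the upstream file): with the defaults above,
# conv_offset maps the 2-channel shape prediction to
# deform_groups * kernel_size * kernel_size * 2 = 4 * 3 * 3 * 2 = 72 offset
# channels, which DeformConv2d then consumes to align the features with the
# predicted anchor shapes.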


@HEADS.register_module()
class GuidedAnchorHead(AnchorHead):
    """Guided-Anchor-based head (GA-RPN, GA-RetinaNet, etc.).

    This GuidedAnchorHead predicts high-quality, feature-guided anchors and
    the locations where anchors are kept during inference.
    There are mainly three categories of bounding boxes:

    - Sampled 9 pairs for target assignment. (approxes)
    - The square boxes where the predicted anchors are based on. (squares)
    - Guided anchors.

    Please refer to https://arxiv.org/abs/1901.03278 for more details.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels.
        approx_anchor_generator (dict): Config dict for approx generator.
        square_anchor_generator (dict): Config dict for square generator.
        anchor_coder (dict): Config dict for anchor coder.
        bbox_coder (dict): Config dict for bbox coder.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Default False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        deform_groups (int): Group number of DCN in the FeatureAdaption
            module.
        loc_filter_thr (float): Threshold to filter out unconcerned regions.
        loss_loc (dict): Config of location loss.
        loss_shape (dict): Config of anchor shape loss.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of bbox regression loss.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
            self,
            num_classes,
            in_channels,
            feat_channels=256,
            approx_anchor_generator=dict(
                type='AnchorGenerator',
                octave_base_scale=8,
                scales_per_octave=3,
                ratios=[0.5, 1.0, 2.0],
                strides=[4, 8, 16, 32, 64]),
            square_anchor_generator=dict(
                type='AnchorGenerator',
                ratios=[1.0],
                scales=[8],
                strides=[4, 8, 16, 32, 64]),
            anchor_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[.0, .0, .0, .0],
                target_stds=[1.0, 1.0, 1.0, 1.0]),
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[.0, .0, .0, .0],
                target_stds=[1.0, 1.0, 1.0, 1.0]),
            reg_decoded_bbox=False,
            deform_groups=4,
            loc_filter_thr=0.01,
            train_cfg=None,
            test_cfg=None,
            loss_loc=dict(
                type='FocalLoss',
                use_sigmoid=True,
                gamma=2.0,
                alpha=0.25,
                loss_weight=1.0),
            loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
            init_cfg=dict(
                type='Normal',
                layer='Conv2d',
                std=0.01,
                override=dict(
                    type='Normal', name='conv_loc', std=0.01,
                    bias_prob=0.01))):  # yapf: disable
        super(AnchorHead, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.deform_groups = deform_groups
        self.loc_filter_thr = loc_filter_thr

        # build approx_anchor_generator and square_anchor_generator
        assert (approx_anchor_generator['octave_base_scale'] ==
                square_anchor_generator['scales'][0])
        assert (approx_anchor_generator['strides'] ==
                square_anchor_generator['strides'])
        self.approx_anchor_generator = build_prior_generator(
            approx_anchor_generator)
        self.square_anchor_generator = build_prior_generator(
            square_anchor_generator)
        self.approxs_per_octave = self.approx_anchor_generator \
            .num_base_priors[0]

        self.reg_decoded_bbox = reg_decoded_bbox

        # one anchor per location
        self.num_base_priors = self.square_anchor_generator.num_base_priors[0]

        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        self.loc_focal_loss = loss_loc['type'] in ['FocalLoss']
        self.sampling = loss_cls['type'] not in ['FocalLoss']
        self.ga_sampling = train_cfg is not None and hasattr(
            train_cfg, 'ga_sampler')
        if self.use_sigmoid_cls:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1

        # build bbox_coder
        self.anchor_coder = build_bbox_coder(anchor_coder)
        self.bbox_coder = build_bbox_coder(bbox_coder)

        # build losses
        self.loss_loc = build_loss(loss_loc)
        self.loss_shape = build_loss(loss_shape)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # use PseudoSampler when sampling is False
            if self.sampling and hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

            self.ga_assigner = build_assigner(self.train_cfg.ga_assigner)
            if self.ga_sampling:
                ga_sampler_cfg = self.train_cfg.ga_sampler
            else:
                ga_sampler_cfg = dict(type='PseudoSampler')
            self.ga_sampler = build_sampler(ga_sampler_cfg, context=self)

        self.fp16_enabled = False

        self._init_layers()

    @property
    def num_anchors(self):
        warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
                      'please use "num_base_priors" instead')
        return self.square_anchor_generator.num_base_priors[0]

    def _init_layers(self):
        self.relu = nn.ReLU(inplace=True)
        self.conv_loc = nn.Conv2d(self.in_channels, 1, 1)
        self.conv_shape = nn.Conv2d(self.in_channels, self.num_base_priors * 2,
                                    1)
        self.feature_adaption = FeatureAdaption(
            self.in_channels,
            self.feat_channels,
            kernel_size=3,
            deform_groups=self.deform_groups)
        self.conv_cls = MaskedConv2d(
            self.feat_channels, self.num_base_priors * self.cls_out_channels,
            1)
        self.conv_reg = MaskedConv2d(self.feat_channels,
                                     self.num_base_priors * 4, 1)

    def forward_single(self, x):
        loc_pred = self.conv_loc(x)
        shape_pred = self.conv_shape(x)
        x = self.feature_adaption(x, shape_pred)
        # masked conv is only used during inference for speed-up
        if not self.training:
            mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
        else:
            mask = None
        cls_score = self.conv_cls(x, mask)
        bbox_pred = self.conv_reg(x, mask)
        return cls_score, bbox_pred, shape_pred, loc_pred

    def forward(self, feats):
        return multi_apply(self.forward_single, feats)

    def get_sampled_approxs(self, featmap_sizes, img_metas, device='cuda'):
        """Get sampled approxs and inside flags according to feature map
        sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.
            device (torch.device | str): device for returned tensors

        Returns:
            tuple: approxes of each image, inside flags of each image
        """
        num_imgs = len(img_metas)

        # since feature map sizes of all images are the same, we only compute
        # approxes for one time
        multi_level_approxs = self.approx_anchor_generator.grid_priors(
            featmap_sizes, device=device)
        approxs_list = [multi_level_approxs for _ in range(num_imgs)]

        # for each image, we compute inside flags of multi level approxes
        inside_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = []
            multi_level_approxs = approxs_list[img_id]

            # obtain valid flags for each approx first
            multi_level_approx_flags = self.approx_anchor_generator \
                .valid_flags(featmap_sizes,
                             img_meta['pad_shape'],
                             device=device)

            for i, flags in enumerate(multi_level_approx_flags):
                approxs = multi_level_approxs[i]
                inside_flags_list = []
                for split_id in range(self.approxs_per_octave):
                    split_valid_flags = flags[split_id::self.
                                              approxs_per_octave]
                    split_approxs = approxs[split_id::self.
                                            approxs_per_octave, :]
                    inside_flags = anchor_inside_flags(
                        split_approxs, split_valid_flags,
                        img_meta['img_shape'][:2],
                        self.train_cfg.allowed_border)
                    inside_flags_list.append(inside_flags)
                # inside_flag for a position is true if any anchor in this
                # position is true
                inside_flags = (
                    torch.stack(inside_flags_list, 0).sum(dim=0) > 0)
                multi_level_flags.append(inside_flags)
            inside_flag_list.append(multi_level_flags)
        return approxs_list, inside_flag_list

    def get_anchors(self,
                    featmap_sizes,
                    shape_preds,
                    loc_preds,
                    img_metas,
                    use_loc_filter=False,
                    device='cuda'):
        """Get squares according to feature map sizes and guided anchors.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            shape_preds (list[tensor]): Multi-level shape predictions.
            loc_preds (list[tensor]): Multi-level location predictions.
            img_metas (list[dict]): Image meta info.
            use_loc_filter (bool): Use loc filter or not.
            device (torch.device | str): device for returned tensors

        Returns:
            tuple: square approxs of each image, guided anchors of each image,
                loc masks of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # squares for one time
        multi_level_squares = self.square_anchor_generator.grid_priors(
            featmap_sizes, device=device)
        squares_list = [multi_level_squares for _ in range(num_imgs)]

        # for each image, we compute multi level guided anchors
        guided_anchors_list = []
        loc_mask_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_guided_anchors = []
            multi_level_loc_mask = []
            for i in range(num_levels):
                squares = squares_list[img_id][i]
                shape_pred = shape_preds[i][img_id]
                loc_pred = loc_preds[i][img_id]
                guided_anchors, loc_mask = self._get_guided_anchors_single(
                    squares,
                    shape_pred,
                    loc_pred,
                    use_loc_filter=use_loc_filter)
                multi_level_guided_anchors.append(guided_anchors)
                multi_level_loc_mask.append(loc_mask)
            guided_anchors_list.append(multi_level_guided_anchors)
            loc_mask_list.append(multi_level_loc_mask)
        return squares_list, guided_anchors_list, loc_mask_list

    def _get_guided_anchors_single(self,
                                   squares,
                                   shape_pred,
                                   loc_pred,
                                   use_loc_filter=False):
        """Get guided anchors and loc masks for a single level.

        Args:
            squares (tensor): Squares of a single level.
            shape_pred (tensor): Shape predictions of a single level.
            loc_pred (tensor): Loc predictions of a single level.
            use_loc_filter (bool): Use loc filter or not.

        Returns:
            tuple: guided anchors, location masks
        """
        # calculate location filtering mask
        loc_pred = loc_pred.sigmoid().detach()
        if use_loc_filter:
            loc_mask = loc_pred >= self.loc_filter_thr
        else:
            loc_mask = loc_pred >= 0.0
        mask = loc_mask.permute(1, 2, 0).expand(-1, -1, self.num_base_priors)
        mask = mask.contiguous().view(-1)
        # calculate guided anchors
        squares = squares[mask]
        anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view(
            -1, 2).detach()[mask]
        bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
        bbox_deltas[:, 2:] = anchor_deltas
        guided_anchors = self.anchor_coder.decode(
            squares, bbox_deltas, wh_ratio_clip=1e-6)
        return guided_anchors, mask
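
    # Note (illustrative, not part of the upstream file): bbox_deltas above
    # keeps dx = dy = 0 and fills only (dw, dh) from the shape prediction, so
    # the decoded guided anchors stay centered on their square anchors while
    # their width and height are learned.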

    def ga_loc_targets(self, gt_bboxes_list, featmap_sizes):
        """Compute location targets for guided anchoring.

        Each feature map is divided into positive, negative and ignore
        regions.

        - positive regions: target 1, weight 1
        - ignore regions: target 0, weight 0
        - negative regions: target 0, weight 0.1

        Args:
            gt_bboxes_list (list[Tensor]): Gt bboxes of each image.
            featmap_sizes (list[tuple]): Multi-level feature map sizes.

        Returns:
            tuple
        """
        anchor_scale = self.approx_anchor_generator.octave_base_scale
        anchor_strides = self.approx_anchor_generator.strides
        # Currently only supports same stride in x and y direction.
        for stride in anchor_strides:
            assert (stride[0] == stride[1])
        anchor_strides = [stride[0] for stride in anchor_strides]

        center_ratio = self.train_cfg.center_ratio
        ignore_ratio = self.train_cfg.ignore_ratio
        img_per_gpu = len(gt_bboxes_list)
        num_lvls = len(featmap_sizes)
        r1 = (1 - center_ratio) / 2
        r2 = (1 - ignore_ratio) / 2
        all_loc_targets = []
        all_loc_weights = []
        all_ignore_map = []
        for lvl_id in range(num_lvls):
            h, w = featmap_sizes[lvl_id]
            loc_targets = torch.zeros(
                img_per_gpu,
                1,
                h,
                w,
                device=gt_bboxes_list[0].device,
                dtype=torch.float32)
            loc_weights = torch.full_like(loc_targets, -1)
            ignore_map = torch.zeros_like(loc_targets)
            all_loc_targets.append(loc_targets)
            all_loc_weights.append(loc_weights)
            all_ignore_map.append(ignore_map)
        for img_id in range(img_per_gpu):
            gt_bboxes = gt_bboxes_list[img_id]
            scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                               (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
            min_anchor_size = scale.new_full(
                (1, ), float(anchor_scale * anchor_strides[0]))
            # assign gt bboxes to different feature levels w.r.t. their scales
            target_lvls = torch.floor(
                torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
            target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()
            for gt_id in range(gt_bboxes.size(0)):
                lvl = target_lvls[gt_id].item()
                # rescaled to corresponding feature map
                gt_ = gt_bboxes[gt_id, :4] / anchor_strides[lvl]
                # calculate ignore regions
                ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                    gt_, r2, featmap_sizes[lvl])
                # calculate positive (center) regions
                ctr_x1, ctr_y1, ctr_x2, ctr_y2 = calc_region(
                    gt_, r1, featmap_sizes[lvl])
                all_loc_targets[lvl][img_id, 0, ctr_y1:ctr_y2 + 1,
                                     ctr_x1:ctr_x2 + 1] = 1
                all_loc_weights[lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
                                     ignore_x1:ignore_x2 + 1] = 0
                all_loc_weights[lvl][img_id, 0, ctr_y1:ctr_y2 + 1,
                                     ctr_x1:ctr_x2 + 1] = 1
                # calculate ignore map on nearby low level feature
                if lvl > 0:
                    d_lvl = lvl - 1
                    # rescaled to corresponding feature map
                    gt_ = gt_bboxes[gt_id, :4] / anchor_strides[d_lvl]
                    ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                        gt_, r2, featmap_sizes[d_lvl])
                    all_ignore_map[d_lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
                                          ignore_x1:ignore_x2 + 1] = 1
                # calculate ignore map on nearby high level feature
                if lvl < num_lvls - 1:
                    u_lvl = lvl + 1
                    # rescaled to corresponding feature map
                    gt_ = gt_bboxes[gt_id, :4] / anchor_strides[u_lvl]
                    ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
                        gt_, r2, featmap_sizes[u_lvl])
                    all_ignore_map[u_lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
                                          ignore_x1:ignore_x2 + 1] = 1
        for lvl_id in range(num_lvls):
            # ignore negative regions w.r.t. ignore map
            all_loc_weights[lvl_id][(all_loc_weights[lvl_id] < 0)
                                    & (all_ignore_map[lvl_id] > 0)] = 0
            # set negative regions with weight 0.1
            all_loc_weights[lvl_id][all_loc_weights[lvl_id] < 0] = 0.1
        # loc average factor to balance loss
        loc_avg_factor = sum(
            [t.size(0) * t.size(-1) * t.size(-2)
             for t in all_loc_targets]) / 200
        return all_loc_targets, all_loc_weights, loc_avg_factor
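
    # Worked note (illustrative, not part of the upstream file): cells inside
    # the shrunken center region of a gt box get target 1 / weight 1, the
    # surrounding ignore band (plus ignore regions projected from adjacent
    # pyramid levels) gets weight 0, and all remaining cells are negatives
    # down-weighted to 0.1; loc_avg_factor = total cells / 200 then
    # normalizes the location loss.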

    def _ga_shape_target_single(self,
                                flat_approxs,
                                inside_flags,
                                flat_squares,
                                gt_bboxes,
                                gt_bboxes_ignore,
                                img_meta,
                                unmap_outputs=True):
        """Compute guided anchoring targets.

        This function returns sampled anchors and gt bboxes directly
        rather than calculating regression targets.

        Args:
            flat_approxs (Tensor): flat approxs of a single image,
                shape (n * approxs_per_octave, 4).
            inside_flags (Tensor): inside flags of a single image,
                shape (n, ).
            flat_squares (Tensor): flat squares of a single image,
                shape (n, 4).
            gt_bboxes (Tensor): Ground truth bboxes of a single image.
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be ignored.
            img_meta (dict): Meta info of a single image.
            unmap_outputs (bool): unmap outputs or not.

        Returns:
            tuple
        """
        if not inside_flags.any():
            return (None, ) * 5
        # assign gt and sample anchors
        expand_inside_flags = inside_flags[:, None].expand(
            -1, self.approxs_per_octave).reshape(-1)
        approxs = flat_approxs[expand_inside_flags, :]
        squares = flat_squares[inside_flags, :]

        assign_result = self.ga_assigner.assign(approxs, squares,
                                                self.approxs_per_octave,
                                                gt_bboxes, gt_bboxes_ignore)
        sampling_result = self.ga_sampler.sample(assign_result, squares,
                                                 gt_bboxes)

        bbox_anchors = torch.zeros_like(squares)
        bbox_gts = torch.zeros_like(squares)
        bbox_weights = torch.zeros_like(squares)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            bbox_anchors[pos_inds, :] = sampling_result.pos_bboxes
            bbox_gts[pos_inds, :] = sampling_result.pos_gt_bboxes
            bbox_weights[pos_inds, :] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_squares.size(0)
            bbox_anchors = unmap(bbox_anchors, num_total_anchors, inside_flags)
            bbox_gts = unmap(bbox_gts, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (bbox_anchors, bbox_gts, bbox_weights, pos_inds, neg_inds)

    def ga_shape_targets(self,
                         approx_list,
                         inside_flag_list,
                         square_list,
                         gt_bboxes_list,
                         img_metas,
                         gt_bboxes_ignore_list=None,
                         unmap_outputs=True):
        """Compute guided anchoring targets.

        Args:
            approx_list (list[list]): Multi level approxs of each image.
            inside_flag_list (list[list]): Multi level inside flags of each
                image.
            square_list (list[list]): Multi level squares of each image.
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): ignore list of gt bboxes.
            unmap_outputs (bool): unmap outputs or not.

        Returns:
            tuple
        """
        num_imgs = len(img_metas)
        assert len(approx_list) == len(inside_flag_list) == len(
            square_list) == num_imgs
        # anchor number of multi levels
        num_level_squares = [squares.size(0) for squares in square_list[0]]
        # concat all level anchors and flags to a single tensor
        inside_flag_flat_list = []
        approx_flat_list = []
        square_flat_list = []
        for i in range(num_imgs):
            assert len(square_list[i]) == len(inside_flag_list[i])
            inside_flag_flat_list.append(torch.cat(inside_flag_list[i]))
            approx_flat_list.append(torch.cat(approx_list[i]))
            square_flat_list.append(torch.cat(square_list[i]))

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        (all_bbox_anchors, all_bbox_gts, all_bbox_weights, pos_inds_list,
         neg_inds_list) = multi_apply(
             self._ga_shape_target_single,
             approx_flat_list,
             inside_flag_flat_list,
             square_flat_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             img_metas,
             unmap_outputs=unmap_outputs)
        # no valid anchors
        if any([bbox_anchors is None for bbox_anchors in all_bbox_anchors]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # split targets to a list w.r.t. multiple levels
        bbox_anchors_list = images_to_levels(all_bbox_anchors,
                                             num_level_squares)
        bbox_gts_list = images_to_levels(all_bbox_gts, num_level_squares)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_squares)
        return (bbox_anchors_list, bbox_gts_list, bbox_weights_list,
                num_total_pos, num_total_neg)

    def loss_shape_single(self, shape_pred, bbox_anchors, bbox_gts,
                          anchor_weights, anchor_total_num):
        shape_pred = shape_pred.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        bbox_anchors = bbox_anchors.contiguous().view(-1, 4)
        bbox_gts = bbox_gts.contiguous().view(-1, 4)
        anchor_weights = anchor_weights.contiguous().view(-1, 4)
        bbox_deltas = bbox_anchors.new_full(bbox_anchors.size(), 0)
        bbox_deltas[:, 2:] += shape_pred
        # filter out negative samples to speed-up weighted_bounded_iou_loss
        inds = torch.nonzero(
            anchor_weights[:, 0] > 0, as_tuple=False).squeeze(1)
        bbox_deltas_ = bbox_deltas[inds]
        bbox_anchors_ = bbox_anchors[inds]
        bbox_gts_ = bbox_gts[inds]
        anchor_weights_ = anchor_weights[inds]
        pred_anchors_ = self.anchor_coder.decode(
            bbox_anchors_, bbox_deltas_, wh_ratio_clip=1e-6)
        loss_shape = self.loss_shape(
            pred_anchors_,
            bbox_gts_,
            anchor_weights_,
            avg_factor=anchor_total_num)
        return loss_shape

    def loss_loc_single(self, loc_pred, loc_target, loc_weight,
                        loc_avg_factor):
        loss_loc = self.loss_loc(
            loc_pred.reshape(-1, 1),
            loc_target.reshape(-1).long(),
            loc_weight.reshape(-1),
            avg_factor=loc_avg_factor)
        return loss_loc

    @force_fp32(
        apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.approx_anchor_generator.num_levels

        device = cls_scores[0].device

        # get loc targets
        loc_targets, loc_weights, loc_avg_factor = self.ga_loc_targets(
            gt_bboxes, featmap_sizes)

        # get sampled approxes
        approxs_list, inside_flag_list = self.get_sampled_approxs(
            featmap_sizes, img_metas, device=device)
        # get squares and guided anchors
        squares_list, guided_anchors_list, _ = self.get_anchors(
            featmap_sizes, shape_preds, loc_preds, img_metas, device=device)

        # get shape targets
        shape_targets = self.ga_shape_targets(approxs_list, inside_flag_list,
                                              squares_list, gt_bboxes,
                                              img_metas)
        if shape_targets is None:
            return None
        (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
         anchor_bg_num) = shape_targets
        anchor_total_num = (
            anchor_fg_num if not self.ga_sampling else anchor_fg_num +
            anchor_bg_num)

        # get anchor targets
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            guided_anchors_list,
            inside_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)

        # anchor number of multi levels
        num_level_anchors = [
            anchors.size(0) for anchors in guided_anchors_list[0]
        ]
        # concat all level anchors to a single tensor
        concat_anchor_list = []
        for i in range(len(guided_anchors_list)):
            concat_anchor_list.append(torch.cat(guided_anchors_list[i]))
        all_anchor_list = images_to_levels(concat_anchor_list,
                                           num_level_anchors)

        # get classification and bbox regression losses
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            all_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples)

        # get anchor location loss
        losses_loc = []
        for i in range(len(loc_preds)):
            loss_loc = self.loss_loc_single(
                loc_preds[i],
                loc_targets[i],
                loc_weights[i],
                loc_avg_factor=loc_avg_factor)
            losses_loc.append(loss_loc)

        # get anchor shape loss
        losses_shape = []
        for i in range(len(shape_preds)):
            loss_shape = self.loss_shape_single(
                shape_preds[i],
                bbox_anchors_list[i],
                bbox_gts_list[i],
                anchor_weights_list[i],
                anchor_total_num=anchor_total_num)
            losses_shape.append(loss_shape)

        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_shape=losses_shape,
            loss_loc=losses_loc)

    @force_fp32(
        apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   shape_preds,
                   loc_preds,
                   img_metas,
                   cfg=None,
                   rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(shape_preds) == len(
            loc_preds)
        num_levels = len(cls_scores)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        device = cls_scores[0].device
        # get guided anchors
        _, guided_anchors, loc_masks = self.get_anchors(
            featmap_sizes,
            shape_preds,
            loc_preds,
            img_metas,
            use_loc_filter=not self.training,
            device=device)
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            guided_anchor_list = [
                guided_anchors[img_id][i].detach() for i in range(num_levels)
            ]
            loc_mask_list = [
                loc_masks[img_id][i].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
                                                guided_anchor_list,
                                                loc_mask_list, img_shape,
                                                scale_factor, cfg, rescale)
            result_list.append(proposals)
        return result_list

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           mlvl_anchors,
                           mlvl_masks,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        mlvl_bboxes = []
        mlvl_scores = []
        for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds,
                                                       mlvl_anchors,
                                                       mlvl_masks):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            # if no location is kept, end.
            if mask.sum() == 0:
                continue
            # reshape scores and bbox_pred
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            # filter scores, bbox_pred w.r.t. mask.
            # anchors are filtered in get_anchors() beforehand.
            scores = scores[mask, :]
            bbox_pred = bbox_pred[mask, :]
            if scores.dim() == 0:
                anchors = anchors.unsqueeze(0)
                scores = scores.unsqueeze(0)
                bbox_pred = bbox_pred.unsqueeze(0)
            # filter anchors, bbox_pred, scores w.r.t. scores
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
                    # remind that we set FG labels to [0, num_class-1]
                    # since mmdet v2.0
                    # BG cat_id: num_class
                    max_scores, _ = scores[:, :-1].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
            bboxes = self.bbox_coder.decode(
                anchors, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        if self.use_sigmoid_cls:
            # Add a dummy background class to the end when using sigmoid
            # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
            # BG cat_id: num_class
            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
            mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        # multi class NMS
        det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                cfg.score_thr, cfg.nms,
                                                cfg.max_per_img)
        return det_bboxes, det_labels
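
# Illustrative usage sketch (not part of the upstream file): builds the head
# with its default generators/coders and runs a forward pass on random FPN
# features. The tensor shapes (batch 2, 256 channels, five levels) are
# assumptions for demonstration only; DeformConv2d may require a CUDA-enabled
# mmcv build.
if __name__ == '__main__':
    head = GuidedAnchorHead(num_classes=80, in_channels=256)
    # one feature map per FPN level, matching strides [4, 8, 16, 32, 64]
    feats = [torch.rand(2, 256, s, s) for s in (64, 32, 16, 8, 4)]
    cls_scores, bbox_preds, shape_preds, loc_preds = head(feats)
    # per level: cls_score has num_base_priors * cls_out_channels channels,
    # bbox_pred has 4, shape_pred has 2 and loc_pred has 1
    print([s.shape for s in cls_scores])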
