You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

reppoints_head.py 35 kB

2 years ago
[Line-number gutter 1–764 from the source viewer, collapsed during extraction.]
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.cnn import ConvModule
  6. from mmcv.ops import DeformConv2d
  7. from mmdet.core import (build_assigner, build_sampler, images_to_levels,
  8. multi_apply, unmap)
  9. from mmdet.core.anchor.point_generator import MlvlPointGenerator
  10. from mmdet.core.utils import filter_scores_and_topk
  11. from ..builder import HEADS, build_loss
  12. from .anchor_free_head import AnchorFreeHead
  13. @HEADS.register_module()
  14. class RepPointsHead(AnchorFreeHead):
  15. """RepPoint head.
  16. Args:
  17. point_feat_channels (int): Number of channels of points features.
  18. gradient_mul (float): The multiplier to gradients from
  19. points refinement and recognition.
  20. point_strides (Iterable): points strides.
  21. point_base_scale (int): bbox scale for assigning labels.
  22. loss_cls (dict): Config of classification loss.
  23. loss_bbox_init (dict): Config of initial points loss.
  24. loss_bbox_refine (dict): Config of points loss in refinement.
  25. use_grid_points (bool): If we use bounding box representation, the
  26. reppoints is represented as grid points on the bounding box.
  27. center_init (bool): Whether to use center point assignment.
  28. transform_method (str): The methods to transform RepPoints to bbox.
  29. init_cfg (dict or list[dict], optional): Initialization config dict.
  30. """ # noqa: W605
  31. def __init__(self,
  32. num_classes,
  33. in_channels,
  34. point_feat_channels=256,
  35. num_points=9,
  36. gradient_mul=0.1,
  37. point_strides=[8, 16, 32, 64, 128],
  38. point_base_scale=4,
  39. loss_cls=dict(
  40. type='FocalLoss',
  41. use_sigmoid=True,
  42. gamma=2.0,
  43. alpha=0.25,
  44. loss_weight=1.0),
  45. loss_bbox_init=dict(
  46. type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
  47. loss_bbox_refine=dict(
  48. type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
  49. use_grid_points=False,
  50. center_init=True,
  51. transform_method='moment',
  52. moment_mul=0.01,
  53. init_cfg=dict(
  54. type='Normal',
  55. layer='Conv2d',
  56. std=0.01,
  57. override=dict(
  58. type='Normal',
  59. name='reppoints_cls_out',
  60. std=0.01,
  61. bias_prob=0.01)),
  62. **kwargs):
  63. self.num_points = num_points
  64. self.point_feat_channels = point_feat_channels
  65. self.use_grid_points = use_grid_points
  66. self.center_init = center_init
  67. # we use deform conv to extract points features
  68. self.dcn_kernel = int(np.sqrt(num_points))
  69. self.dcn_pad = int((self.dcn_kernel - 1) / 2)
  70. assert self.dcn_kernel * self.dcn_kernel == num_points, \
  71. 'The points number should be a square number.'
  72. assert self.dcn_kernel % 2 == 1, \
  73. 'The points number should be an odd square number.'
  74. dcn_base = np.arange(-self.dcn_pad,
  75. self.dcn_pad + 1).astype(np.float64)
  76. dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
  77. dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
  78. dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
  79. (-1))
  80. self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
  81. super().__init__(
  82. num_classes,
  83. in_channels,
  84. loss_cls=loss_cls,
  85. init_cfg=init_cfg,
  86. **kwargs)
  87. self.gradient_mul = gradient_mul
  88. self.point_base_scale = point_base_scale
  89. self.point_strides = point_strides
  90. self.prior_generator = MlvlPointGenerator(
  91. self.point_strides, offset=0.)
  92. self.sampling = loss_cls['type'] not in ['FocalLoss']
  93. if self.train_cfg:
  94. self.init_assigner = build_assigner(self.train_cfg.init.assigner)
  95. self.refine_assigner = build_assigner(
  96. self.train_cfg.refine.assigner)
  97. # use PseudoSampler when sampling is False
  98. if self.sampling and hasattr(self.train_cfg, 'sampler'):
  99. sampler_cfg = self.train_cfg.sampler
  100. else:
  101. sampler_cfg = dict(type='PseudoSampler')
  102. self.sampler = build_sampler(sampler_cfg, context=self)
  103. self.transform_method = transform_method
  104. if self.transform_method == 'moment':
  105. self.moment_transfer = nn.Parameter(
  106. data=torch.zeros(2), requires_grad=True)
  107. self.moment_mul = moment_mul
  108. self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
  109. if self.use_sigmoid_cls:
  110. self.cls_out_channels = self.num_classes
  111. else:
  112. self.cls_out_channels = self.num_classes + 1
  113. self.loss_bbox_init = build_loss(loss_bbox_init)
  114. self.loss_bbox_refine = build_loss(loss_bbox_refine)
  115. def _init_layers(self):
  116. """Initialize layers of the head."""
  117. self.relu = nn.ReLU(inplace=True)
  118. self.cls_convs = nn.ModuleList()
  119. self.reg_convs = nn.ModuleList()
  120. for i in range(self.stacked_convs):
  121. chn = self.in_channels if i == 0 else self.feat_channels
  122. self.cls_convs.append(
  123. ConvModule(
  124. chn,
  125. self.feat_channels,
  126. 3,
  127. stride=1,
  128. padding=1,
  129. conv_cfg=self.conv_cfg,
  130. norm_cfg=self.norm_cfg))
  131. self.reg_convs.append(
  132. ConvModule(
  133. chn,
  134. self.feat_channels,
  135. 3,
  136. stride=1,
  137. padding=1,
  138. conv_cfg=self.conv_cfg,
  139. norm_cfg=self.norm_cfg))
  140. pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
  141. self.reppoints_cls_conv = DeformConv2d(self.feat_channels,
  142. self.point_feat_channels,
  143. self.dcn_kernel, 1,
  144. self.dcn_pad)
  145. self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
  146. self.cls_out_channels, 1, 1, 0)
  147. self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
  148. self.point_feat_channels, 3,
  149. 1, 1)
  150. self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
  151. pts_out_dim, 1, 1, 0)
  152. self.reppoints_pts_refine_conv = DeformConv2d(self.feat_channels,
  153. self.point_feat_channels,
  154. self.dcn_kernel, 1,
  155. self.dcn_pad)
  156. self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
  157. pts_out_dim, 1, 1, 0)
  158. def points2bbox(self, pts, y_first=True):
  159. """Converting the points set into bounding box.
  160. :param pts: the input points sets (fields), each points
  161. set (fields) is represented as 2n scalar.
  162. :param y_first: if y_first=True, the point set is represented as
  163. [y1, x1, y2, x2 ... yn, xn], otherwise the point set is
  164. represented as [x1, y1, x2, y2 ... xn, yn].
  165. :return: each points set is converting to a bbox [x1, y1, x2, y2].
  166. """
  167. pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
  168. pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1,
  169. ...]
  170. pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0,
  171. ...]
  172. if self.transform_method == 'minmax':
  173. bbox_left = pts_x.min(dim=1, keepdim=True)[0]
  174. bbox_right = pts_x.max(dim=1, keepdim=True)[0]
  175. bbox_up = pts_y.min(dim=1, keepdim=True)[0]
  176. bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
  177. bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
  178. dim=1)
  179. elif self.transform_method == 'partial_minmax':
  180. pts_y = pts_y[:, :4, ...]
  181. pts_x = pts_x[:, :4, ...]
  182. bbox_left = pts_x.min(dim=1, keepdim=True)[0]
  183. bbox_right = pts_x.max(dim=1, keepdim=True)[0]
  184. bbox_up = pts_y.min(dim=1, keepdim=True)[0]
  185. bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
  186. bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
  187. dim=1)
  188. elif self.transform_method == 'moment':
  189. pts_y_mean = pts_y.mean(dim=1, keepdim=True)
  190. pts_x_mean = pts_x.mean(dim=1, keepdim=True)
  191. pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
  192. pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
  193. moment_transfer = (self.moment_transfer * self.moment_mul) + (
  194. self.moment_transfer.detach() * (1 - self.moment_mul))
  195. moment_width_transfer = moment_transfer[0]
  196. moment_height_transfer = moment_transfer[1]
  197. half_width = pts_x_std * torch.exp(moment_width_transfer)
  198. half_height = pts_y_std * torch.exp(moment_height_transfer)
  199. bbox = torch.cat([
  200. pts_x_mean - half_width, pts_y_mean - half_height,
  201. pts_x_mean + half_width, pts_y_mean + half_height
  202. ],
  203. dim=1)
  204. else:
  205. raise NotImplementedError
  206. return bbox
  207. def gen_grid_from_reg(self, reg, previous_boxes):
  208. """Base on the previous bboxes and regression values, we compute the
  209. regressed bboxes and generate the grids on the bboxes.
  210. :param reg: the regression value to previous bboxes.
  211. :param previous_boxes: previous bboxes.
  212. :return: generate grids on the regressed bboxes.
  213. """
  214. b, _, h, w = reg.shape
  215. bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
  216. bwh = (previous_boxes[:, 2:, ...] -
  217. previous_boxes[:, :2, ...]).clamp(min=1e-6)
  218. grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
  219. reg[:, 2:, ...])
  220. grid_wh = bwh * torch.exp(reg[:, 2:, ...])
  221. grid_left = grid_topleft[:, [0], ...]
  222. grid_top = grid_topleft[:, [1], ...]
  223. grid_width = grid_wh[:, [0], ...]
  224. grid_height = grid_wh[:, [1], ...]
  225. intervel = torch.linspace(0., 1., self.dcn_kernel).view(
  226. 1, self.dcn_kernel, 1, 1).type_as(reg)
  227. grid_x = grid_left + grid_width * intervel
  228. grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
  229. grid_x = grid_x.view(b, -1, h, w)
  230. grid_y = grid_top + grid_height * intervel
  231. grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
  232. grid_y = grid_y.view(b, -1, h, w)
  233. grid_yx = torch.stack([grid_y, grid_x], dim=2)
  234. grid_yx = grid_yx.view(b, -1, h, w)
  235. regressed_bbox = torch.cat([
  236. grid_left, grid_top, grid_left + grid_width, grid_top + grid_height
  237. ], 1)
  238. return grid_yx, regressed_bbox
  239. def forward(self, feats):
  240. return multi_apply(self.forward_single, feats)
  241. def forward_single(self, x):
  242. """Forward feature map of a single FPN level."""
  243. dcn_base_offset = self.dcn_base_offset.type_as(x)
  244. # If we use center_init, the initial reppoints is from center points.
  245. # If we use bounding bbox representation, the initial reppoints is
  246. # from regular grid placed on a pre-defined bbox.
  247. if self.use_grid_points or not self.center_init:
  248. scale = self.point_base_scale / 2
  249. points_init = dcn_base_offset / dcn_base_offset.max() * scale
  250. bbox_init = x.new_tensor([-scale, -scale, scale,
  251. scale]).view(1, 4, 1, 1)
  252. else:
  253. points_init = 0
  254. cls_feat = x
  255. pts_feat = x
  256. for cls_conv in self.cls_convs:
  257. cls_feat = cls_conv(cls_feat)
  258. for reg_conv in self.reg_convs:
  259. pts_feat = reg_conv(pts_feat)
  260. # initialize reppoints
  261. pts_out_init = self.reppoints_pts_init_out(
  262. self.relu(self.reppoints_pts_init_conv(pts_feat)))
  263. if self.use_grid_points:
  264. pts_out_init, bbox_out_init = self.gen_grid_from_reg(
  265. pts_out_init, bbox_init.detach())
  266. else:
  267. pts_out_init = pts_out_init + points_init
  268. # refine and classify reppoints
  269. pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach(
  270. ) + self.gradient_mul * pts_out_init
  271. dcn_offset = pts_out_init_grad_mul - dcn_base_offset
  272. cls_out = self.reppoints_cls_out(
  273. self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
  274. pts_out_refine = self.reppoints_pts_refine_out(
  275. self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
  276. if self.use_grid_points:
  277. pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(
  278. pts_out_refine, bbox_out_init.detach())
  279. else:
  280. pts_out_refine = pts_out_refine + pts_out_init.detach()
  281. if self.training:
  282. return cls_out, pts_out_init, pts_out_refine
  283. else:
  284. return cls_out, self.points2bbox(pts_out_refine)
  285. def get_points(self, featmap_sizes, img_metas, device):
  286. """Get points according to feature map sizes.
  287. Args:
  288. featmap_sizes (list[tuple]): Multi-level feature map sizes.
  289. img_metas (list[dict]): Image meta info.
  290. Returns:
  291. tuple: points of each image, valid flags of each image
  292. """
  293. num_imgs = len(img_metas)
  294. # since feature map sizes of all images are the same, we only compute
  295. # points center for one time
  296. multi_level_points = self.prior_generator.grid_priors(
  297. featmap_sizes, device=device, with_stride=True)
  298. points_list = [[point.clone() for point in multi_level_points]
  299. for _ in range(num_imgs)]
  300. # for each image, we compute valid flags of multi level grids
  301. valid_flag_list = []
  302. for img_id, img_meta in enumerate(img_metas):
  303. multi_level_flags = self.prior_generator.valid_flags(
  304. featmap_sizes, img_meta['pad_shape'])
  305. valid_flag_list.append(multi_level_flags)
  306. return points_list, valid_flag_list
  307. def centers_to_bboxes(self, point_list):
  308. """Get bboxes according to center points.
  309. Only used in :class:`MaxIoUAssigner`.
  310. """
  311. bbox_list = []
  312. for i_img, point in enumerate(point_list):
  313. bbox = []
  314. for i_lvl in range(len(self.point_strides)):
  315. scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
  316. bbox_shift = torch.Tensor([-scale, -scale, scale,
  317. scale]).view(1, 4).type_as(point[0])
  318. bbox_center = torch.cat(
  319. [point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
  320. bbox.append(bbox_center + bbox_shift)
  321. bbox_list.append(bbox)
  322. return bbox_list
  323. def offset_to_pts(self, center_list, pred_list):
  324. """Change from point offset to point coordinate."""
  325. pts_list = []
  326. for i_lvl in range(len(self.point_strides)):
  327. pts_lvl = []
  328. for i_img in range(len(center_list)):
  329. pts_center = center_list[i_img][i_lvl][:, :2].repeat(
  330. 1, self.num_points)
  331. pts_shift = pred_list[i_lvl][i_img]
  332. yx_pts_shift = pts_shift.permute(1, 2, 0).view(
  333. -1, 2 * self.num_points)
  334. y_pts_shift = yx_pts_shift[..., 0::2]
  335. x_pts_shift = yx_pts_shift[..., 1::2]
  336. xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
  337. xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
  338. pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
  339. pts_lvl.append(pts)
  340. pts_lvl = torch.stack(pts_lvl, 0)
  341. pts_list.append(pts_lvl)
  342. return pts_list
  343. def _point_target_single(self,
  344. flat_proposals,
  345. valid_flags,
  346. gt_bboxes,
  347. gt_bboxes_ignore,
  348. gt_labels,
  349. stage='init',
  350. unmap_outputs=True):
  351. inside_flags = valid_flags
  352. if not inside_flags.any():
  353. return (None, ) * 7
  354. # assign gt and sample proposals
  355. proposals = flat_proposals[inside_flags, :]
  356. if stage == 'init':
  357. assigner = self.init_assigner
  358. pos_weight = self.train_cfg.init.pos_weight
  359. else:
  360. assigner = self.refine_assigner
  361. pos_weight = self.train_cfg.refine.pos_weight
  362. assign_result = assigner.assign(proposals, gt_bboxes, gt_bboxes_ignore,
  363. None if self.sampling else gt_labels)
  364. sampling_result = self.sampler.sample(assign_result, proposals,
  365. gt_bboxes)
  366. num_valid_proposals = proposals.shape[0]
  367. bbox_gt = proposals.new_zeros([num_valid_proposals, 4])
  368. pos_proposals = torch.zeros_like(proposals)
  369. proposals_weights = proposals.new_zeros([num_valid_proposals, 4])
  370. labels = proposals.new_full((num_valid_proposals, ),
  371. self.num_classes,
  372. dtype=torch.long)
  373. label_weights = proposals.new_zeros(
  374. num_valid_proposals, dtype=torch.float)
  375. pos_inds = sampling_result.pos_inds
  376. neg_inds = sampling_result.neg_inds
  377. if len(pos_inds) > 0:
  378. pos_gt_bboxes = sampling_result.pos_gt_bboxes
  379. bbox_gt[pos_inds, :] = pos_gt_bboxes
  380. pos_proposals[pos_inds, :] = proposals[pos_inds, :]
  381. proposals_weights[pos_inds, :] = 1.0
  382. if gt_labels is None:
  383. # Only rpn gives gt_labels as None
  384. # Foreground is the first class
  385. labels[pos_inds] = 0
  386. else:
  387. labels[pos_inds] = gt_labels[
  388. sampling_result.pos_assigned_gt_inds]
  389. if pos_weight <= 0:
  390. label_weights[pos_inds] = 1.0
  391. else:
  392. label_weights[pos_inds] = pos_weight
  393. if len(neg_inds) > 0:
  394. label_weights[neg_inds] = 1.0
  395. # map up to original set of proposals
  396. if unmap_outputs:
  397. num_total_proposals = flat_proposals.size(0)
  398. labels = unmap(labels, num_total_proposals, inside_flags)
  399. label_weights = unmap(label_weights, num_total_proposals,
  400. inside_flags)
  401. bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
  402. pos_proposals = unmap(pos_proposals, num_total_proposals,
  403. inside_flags)
  404. proposals_weights = unmap(proposals_weights, num_total_proposals,
  405. inside_flags)
  406. return (labels, label_weights, bbox_gt, pos_proposals,
  407. proposals_weights, pos_inds, neg_inds)
  408. def get_targets(self,
  409. proposals_list,
  410. valid_flag_list,
  411. gt_bboxes_list,
  412. img_metas,
  413. gt_bboxes_ignore_list=None,
  414. gt_labels_list=None,
  415. stage='init',
  416. label_channels=1,
  417. unmap_outputs=True):
  418. """Compute corresponding GT box and classification targets for
  419. proposals.
  420. Args:
  421. proposals_list (list[list]): Multi level points/bboxes of each
  422. image.
  423. valid_flag_list (list[list]): Multi level valid flags of each
  424. image.
  425. gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
  426. img_metas (list[dict]): Meta info of each image.
  427. gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
  428. ignored.
  429. gt_bboxes_list (list[Tensor]): Ground truth labels of each box.
  430. stage (str): `init` or `refine`. Generate target for init stage or
  431. refine stage
  432. label_channels (int): Channel of label.
  433. unmap_outputs (bool): Whether to map outputs back to the original
  434. set of anchors.
  435. Returns:
  436. tuple:
  437. - labels_list (list[Tensor]): Labels of each level.
  438. - label_weights_list (list[Tensor]): Label weights of each level. # noqa: E501
  439. - bbox_gt_list (list[Tensor]): Ground truth bbox of each level.
  440. - proposal_list (list[Tensor]): Proposals(points/bboxes) of each level. # noqa: E501
  441. - proposal_weights_list (list[Tensor]): Proposal weights of each level. # noqa: E501
  442. - num_total_pos (int): Number of positive samples in all images. # noqa: E501
  443. - num_total_neg (int): Number of negative samples in all images. # noqa: E501
  444. """
  445. assert stage in ['init', 'refine']
  446. num_imgs = len(img_metas)
  447. assert len(proposals_list) == len(valid_flag_list) == num_imgs
  448. # points number of multi levels
  449. num_level_proposals = [points.size(0) for points in proposals_list[0]]
  450. # concat all level points and flags to a single tensor
  451. for i in range(num_imgs):
  452. assert len(proposals_list[i]) == len(valid_flag_list[i])
  453. proposals_list[i] = torch.cat(proposals_list[i])
  454. valid_flag_list[i] = torch.cat(valid_flag_list[i])
  455. # compute targets for each image
  456. if gt_bboxes_ignore_list is None:
  457. gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
  458. if gt_labels_list is None:
  459. gt_labels_list = [None for _ in range(num_imgs)]
  460. (all_labels, all_label_weights, all_bbox_gt, all_proposals,
  461. all_proposal_weights, pos_inds_list, neg_inds_list) = multi_apply(
  462. self._point_target_single,
  463. proposals_list,
  464. valid_flag_list,
  465. gt_bboxes_list,
  466. gt_bboxes_ignore_list,
  467. gt_labels_list,
  468. stage=stage,
  469. unmap_outputs=unmap_outputs)
  470. # no valid points
  471. if any([labels is None for labels in all_labels]):
  472. return None
  473. # sampled points of all images
  474. num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
  475. num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
  476. labels_list = images_to_levels(all_labels, num_level_proposals)
  477. label_weights_list = images_to_levels(all_label_weights,
  478. num_level_proposals)
  479. bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
  480. proposals_list = images_to_levels(all_proposals, num_level_proposals)
  481. proposal_weights_list = images_to_levels(all_proposal_weights,
  482. num_level_proposals)
  483. return (labels_list, label_weights_list, bbox_gt_list, proposals_list,
  484. proposal_weights_list, num_total_pos, num_total_neg)
  485. def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels,
  486. label_weights, bbox_gt_init, bbox_weights_init,
  487. bbox_gt_refine, bbox_weights_refine, stride,
  488. num_total_samples_init, num_total_samples_refine):
  489. # classification loss
  490. labels = labels.reshape(-1)
  491. label_weights = label_weights.reshape(-1)
  492. cls_score = cls_score.permute(0, 2, 3,
  493. 1).reshape(-1, self.cls_out_channels)
  494. cls_score = cls_score.contiguous()
  495. loss_cls = self.loss_cls(
  496. cls_score,
  497. labels,
  498. label_weights,
  499. avg_factor=num_total_samples_refine)
  500. # points loss
  501. bbox_gt_init = bbox_gt_init.reshape(-1, 4)
  502. bbox_weights_init = bbox_weights_init.reshape(-1, 4)
  503. bbox_pred_init = self.points2bbox(
  504. pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
  505. bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
  506. bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
  507. bbox_pred_refine = self.points2bbox(
  508. pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
  509. normalize_term = self.point_base_scale * stride
  510. loss_pts_init = self.loss_bbox_init(
  511. bbox_pred_init / normalize_term,
  512. bbox_gt_init / normalize_term,
  513. bbox_weights_init,
  514. avg_factor=num_total_samples_init)
  515. loss_pts_refine = self.loss_bbox_refine(
  516. bbox_pred_refine / normalize_term,
  517. bbox_gt_refine / normalize_term,
  518. bbox_weights_refine,
  519. avg_factor=num_total_samples_refine)
  520. return loss_cls, loss_pts_init, loss_pts_refine
  521. def loss(self,
  522. cls_scores,
  523. pts_preds_init,
  524. pts_preds_refine,
  525. gt_bboxes,
  526. gt_labels,
  527. img_metas,
  528. gt_bboxes_ignore=None):
  529. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  530. device = cls_scores[0].device
  531. label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
  532. # target for initial stage
  533. center_list, valid_flag_list = self.get_points(featmap_sizes,
  534. img_metas, device)
  535. pts_coordinate_preds_init = self.offset_to_pts(center_list,
  536. pts_preds_init)
  537. if self.train_cfg.init.assigner['type'] == 'PointAssigner':
  538. # Assign target for center list
  539. candidate_list = center_list
  540. else:
  541. # transform center list to bbox list and
  542. # assign target for bbox list
  543. bbox_list = self.centers_to_bboxes(center_list)
  544. candidate_list = bbox_list
  545. cls_reg_targets_init = self.get_targets(
  546. candidate_list,
  547. valid_flag_list,
  548. gt_bboxes,
  549. img_metas,
  550. gt_bboxes_ignore_list=gt_bboxes_ignore,
  551. gt_labels_list=gt_labels,
  552. stage='init',
  553. label_channels=label_channels)
  554. (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
  555. num_total_pos_init, num_total_neg_init) = cls_reg_targets_init
  556. num_total_samples_init = (
  557. num_total_pos_init +
  558. num_total_neg_init if self.sampling else num_total_pos_init)
  559. # target for refinement stage
  560. center_list, valid_flag_list = self.get_points(featmap_sizes,
  561. img_metas, device)
  562. pts_coordinate_preds_refine = self.offset_to_pts(
  563. center_list, pts_preds_refine)
  564. bbox_list = []
  565. for i_img, center in enumerate(center_list):
  566. bbox = []
  567. for i_lvl in range(len(pts_preds_refine)):
  568. bbox_preds_init = self.points2bbox(
  569. pts_preds_init[i_lvl].detach())
  570. bbox_shift = bbox_preds_init * self.point_strides[i_lvl]
  571. bbox_center = torch.cat(
  572. [center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1)
  573. bbox.append(bbox_center +
  574. bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
  575. bbox_list.append(bbox)
  576. cls_reg_targets_refine = self.get_targets(
  577. bbox_list,
  578. valid_flag_list,
  579. gt_bboxes,
  580. img_metas,
  581. gt_bboxes_ignore_list=gt_bboxes_ignore,
  582. gt_labels_list=gt_labels,
  583. stage='refine',
  584. label_channels=label_channels)
  585. (labels_list, label_weights_list, bbox_gt_list_refine,
  586. candidate_list_refine, bbox_weights_list_refine, num_total_pos_refine,
  587. num_total_neg_refine) = cls_reg_targets_refine
  588. num_total_samples_refine = (
  589. num_total_pos_refine +
  590. num_total_neg_refine if self.sampling else num_total_pos_refine)
  591. # compute loss
  592. losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
  593. self.loss_single,
  594. cls_scores,
  595. pts_coordinate_preds_init,
  596. pts_coordinate_preds_refine,
  597. labels_list,
  598. label_weights_list,
  599. bbox_gt_list_init,
  600. bbox_weights_list_init,
  601. bbox_gt_list_refine,
  602. bbox_weights_list_refine,
  603. self.point_strides,
  604. num_total_samples_init=num_total_samples_init,
  605. num_total_samples_refine=num_total_samples_refine)
  606. loss_dict_all = {
  607. 'loss_cls': losses_cls,
  608. 'loss_pts_init': losses_pts_init,
  609. 'loss_pts_refine': losses_pts_refine
  610. }
  611. return loss_dict_all
  612. # Same as base_dense_head/_get_bboxes_single except self._bbox_decode
  613. def _get_bboxes_single(self,
  614. cls_score_list,
  615. bbox_pred_list,
  616. score_factor_list,
  617. mlvl_priors,
  618. img_meta,
  619. cfg,
  620. rescale=False,
  621. with_nms=True,
  622. **kwargs):
  623. """Transform outputs of a single image into bbox predictions.
  624. Args:
  625. cls_score_list (list[Tensor]): Box scores from all scale
  626. levels of a single image, each item has shape
  627. (num_priors * num_classes, H, W).
  628. bbox_pred_list (list[Tensor]): Box energies / deltas from
  629. all scale levels of a single image, each item has shape
  630. (num_priors * 4, H, W).
  631. score_factor_list (list[Tensor]): Score factor from all scale
  632. levels of a single image. RepPoints head does not need
  633. this value.
  634. mlvl_priors (list[Tensor]): Each element in the list is
  635. the priors of a single level in feature pyramid, has shape
  636. (num_priors, 2).
  637. img_meta (dict): Image meta info.
  638. cfg (mmcv.Config): Test / postprocessing configuration,
  639. if None, test_cfg would be used.
  640. rescale (bool): If True, return boxes in original image space.
  641. Default: False.
  642. with_nms (bool): If True, do nms before return boxes.
  643. Default: True.
  644. Returns:
  645. tuple[Tensor]: Results of detected bboxes and labels. If with_nms
  646. is False and mlvl_score_factor is None, return mlvl_bboxes and
  647. mlvl_scores, else return mlvl_bboxes, mlvl_scores and
  648. mlvl_score_factor. Usually with_nms is False is used for aug
  649. test. If with_nms is True, then return the following format
  650. - det_bboxes (Tensor): Predicted bboxes with shape \
  651. [num_bboxes, 5], where the first 4 columns are bounding \
  652. box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
  653. column are scores between 0 and 1.
  654. - det_labels (Tensor): Predicted labels of the corresponding \
  655. box with shape [num_bboxes].
  656. """
  657. cfg = self.test_cfg if cfg is None else cfg
  658. assert len(cls_score_list) == len(bbox_pred_list)
  659. img_shape = img_meta['img_shape']
  660. nms_pre = cfg.get('nms_pre', -1)
  661. mlvl_bboxes = []
  662. mlvl_scores = []
  663. mlvl_labels = []
  664. for level_idx, (cls_score, bbox_pred, priors) in enumerate(
  665. zip(cls_score_list, bbox_pred_list, mlvl_priors)):
  666. assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
  667. bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
  668. cls_score = cls_score.permute(1, 2,
  669. 0).reshape(-1, self.cls_out_channels)
  670. if self.use_sigmoid_cls:
  671. scores = cls_score.sigmoid()
  672. else:
  673. scores = cls_score.softmax(-1)[:, :-1]
  674. # After https://github.com/open-mmlab/mmdetection/pull/6268/,
  675. # this operation keeps fewer bboxes under the same `nms_pre`.
  676. # There is no difference in performance for most models. If you
  677. # find a slight drop in performance, you can set a larger
  678. # `nms_pre` than before.
  679. results = filter_scores_and_topk(
  680. scores, cfg.score_thr, nms_pre,
  681. dict(bbox_pred=bbox_pred, priors=priors))
  682. scores, labels, _, filtered_results = results
  683. bbox_pred = filtered_results['bbox_pred']
  684. priors = filtered_results['priors']
  685. bboxes = self._bbox_decode(priors, bbox_pred,
  686. self.point_strides[level_idx],
  687. img_shape)
  688. mlvl_bboxes.append(bboxes)
  689. mlvl_scores.append(scores)
  690. mlvl_labels.append(labels)
  691. return self._bbox_post_process(
  692. mlvl_scores,
  693. mlvl_labels,
  694. mlvl_bboxes,
  695. img_meta['scale_factor'],
  696. cfg,
  697. rescale=rescale,
  698. with_nms=with_nms)
  699. def _bbox_decode(self, points, bbox_pred, stride, max_shape):
  700. bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
  701. bboxes = bbox_pred * stride + bbox_pos_center
  702. x1 = bboxes[:, 0].clamp(min=0, max=max_shape[1])
  703. y1 = bboxes[:, 1].clamp(min=0, max=max_shape[0])
  704. x2 = bboxes[:, 2].clamp(min=0, max=max_shape[1])
  705. y2 = bboxes[:, 3].clamp(min=0, max=max_shape[0])
  706. decoded_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
  707. return decoded_bboxes

No Description

Contributors (3)