You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

centripetal_head.py 20 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. from mmcv.cnn import ConvModule, normal_init
  4. from mmcv.ops import DeformConv2d
  5. from mmdet.core import multi_apply
  6. from ..builder import HEADS, build_loss
  7. from .corner_head import CornerHead
  8. @HEADS.register_module()
  9. class CentripetalHead(CornerHead):
  10. """Head of CentripetalNet: Pursuing High-quality Keypoint Pairs for Object
  11. Detection.
  12. CentripetalHead inherits from :class:`CornerHead`. It removes the
  13. embedding branch and adds guiding shift and centripetal shift branches.
  14. More details can be found in the `paper
  15. <https://arxiv.org/abs/2003.09119>`_ .
  16. Args:
  17. num_classes (int): Number of categories excluding the background
  18. category.
  19. in_channels (int): Number of channels in the input feature map.
  20. num_feat_levels (int): Levels of feature from the previous module. 2
  21. for HourglassNet-104 and 1 for HourglassNet-52. HourglassNet-104
  22. outputs the final feature and intermediate supervision feature and
  23. HourglassNet-52 only outputs the final feature. Default: 2.
  24. corner_emb_channels (int): Channel of embedding vector. Default: 1.
  25. train_cfg (dict | None): Training config. Useless in CornerHead,
  26. but we keep this variable for SingleStageDetector. Default: None.
  27. test_cfg (dict | None): Testing config of CornerHead. Default: None.
  28. loss_heatmap (dict | None): Config of corner heatmap loss. Default:
  29. GaussianFocalLoss.
  30. loss_embedding (dict | None): Config of corner embedding loss. Default:
  31. AssociativeEmbeddingLoss.
  32. loss_offset (dict | None): Config of corner offset loss. Default:
  33. SmoothL1Loss.
  34. loss_guiding_shift (dict): Config of guiding shift loss. Default:
  35. SmoothL1Loss.
  36. loss_centripetal_shift (dict): Config of centripetal shift loss.
  37. Default: SmoothL1Loss.
  38. init_cfg (dict or list[dict], optional): Initialization config dict.
  39. Default: None
  40. """
  41. def __init__(self,
  42. *args,
  43. centripetal_shift_channels=2,
  44. guiding_shift_channels=2,
  45. feat_adaption_conv_kernel=3,
  46. loss_guiding_shift=dict(
  47. type='SmoothL1Loss', beta=1.0, loss_weight=0.05),
  48. loss_centripetal_shift=dict(
  49. type='SmoothL1Loss', beta=1.0, loss_weight=1),
  50. init_cfg=None,
  51. **kwargs):
  52. assert init_cfg is None, 'To prevent abnormal initialization ' \
  53. 'behavior, init_cfg is not allowed to be set'
  54. assert centripetal_shift_channels == 2, (
  55. 'CentripetalHead only support centripetal_shift_channels == 2')
  56. self.centripetal_shift_channels = centripetal_shift_channels
  57. assert guiding_shift_channels == 2, (
  58. 'CentripetalHead only support guiding_shift_channels == 2')
  59. self.guiding_shift_channels = guiding_shift_channels
  60. self.feat_adaption_conv_kernel = feat_adaption_conv_kernel
  61. super(CentripetalHead, self).__init__(
  62. *args, init_cfg=init_cfg, **kwargs)
  63. self.loss_guiding_shift = build_loss(loss_guiding_shift)
  64. self.loss_centripetal_shift = build_loss(loss_centripetal_shift)
  65. def _init_centripetal_layers(self):
  66. """Initialize centripetal layers.
  67. Including feature adaption deform convs (feat_adaption), deform offset
  68. prediction convs (dcn_off), guiding shift (guiding_shift) and
  69. centripetal shift ( centripetal_shift). Each branch has two parts:
  70. prefix `tl_` for top-left and `br_` for bottom-right.
  71. """
  72. self.tl_feat_adaption = nn.ModuleList()
  73. self.br_feat_adaption = nn.ModuleList()
  74. self.tl_dcn_offset = nn.ModuleList()
  75. self.br_dcn_offset = nn.ModuleList()
  76. self.tl_guiding_shift = nn.ModuleList()
  77. self.br_guiding_shift = nn.ModuleList()
  78. self.tl_centripetal_shift = nn.ModuleList()
  79. self.br_centripetal_shift = nn.ModuleList()
  80. for _ in range(self.num_feat_levels):
  81. self.tl_feat_adaption.append(
  82. DeformConv2d(self.in_channels, self.in_channels,
  83. self.feat_adaption_conv_kernel, 1, 1))
  84. self.br_feat_adaption.append(
  85. DeformConv2d(self.in_channels, self.in_channels,
  86. self.feat_adaption_conv_kernel, 1, 1))
  87. self.tl_guiding_shift.append(
  88. self._make_layers(
  89. out_channels=self.guiding_shift_channels,
  90. in_channels=self.in_channels))
  91. self.br_guiding_shift.append(
  92. self._make_layers(
  93. out_channels=self.guiding_shift_channels,
  94. in_channels=self.in_channels))
  95. self.tl_dcn_offset.append(
  96. ConvModule(
  97. self.guiding_shift_channels,
  98. self.feat_adaption_conv_kernel**2 *
  99. self.guiding_shift_channels,
  100. 1,
  101. bias=False,
  102. act_cfg=None))
  103. self.br_dcn_offset.append(
  104. ConvModule(
  105. self.guiding_shift_channels,
  106. self.feat_adaption_conv_kernel**2 *
  107. self.guiding_shift_channels,
  108. 1,
  109. bias=False,
  110. act_cfg=None))
  111. self.tl_centripetal_shift.append(
  112. self._make_layers(
  113. out_channels=self.centripetal_shift_channels,
  114. in_channels=self.in_channels))
  115. self.br_centripetal_shift.append(
  116. self._make_layers(
  117. out_channels=self.centripetal_shift_channels,
  118. in_channels=self.in_channels))
  119. def _init_layers(self):
  120. """Initialize layers for CentripetalHead.
  121. Including two parts: CornerHead layers and CentripetalHead layers
  122. """
  123. super()._init_layers() # using _init_layers in CornerHead
  124. self._init_centripetal_layers()
  125. def init_weights(self):
  126. super(CentripetalHead, self).init_weights()
  127. for i in range(self.num_feat_levels):
  128. normal_init(self.tl_feat_adaption[i], std=0.01)
  129. normal_init(self.br_feat_adaption[i], std=0.01)
  130. normal_init(self.tl_dcn_offset[i].conv, std=0.1)
  131. normal_init(self.br_dcn_offset[i].conv, std=0.1)
  132. _ = [x.conv.reset_parameters() for x in self.tl_guiding_shift[i]]
  133. _ = [x.conv.reset_parameters() for x in self.br_guiding_shift[i]]
  134. _ = [
  135. x.conv.reset_parameters() for x in self.tl_centripetal_shift[i]
  136. ]
  137. _ = [
  138. x.conv.reset_parameters() for x in self.br_centripetal_shift[i]
  139. ]
  140. def forward_single(self, x, lvl_ind):
  141. """Forward feature of a single level.
  142. Args:
  143. x (Tensor): Feature of a single level.
  144. lvl_ind (int): Level index of current feature.
  145. Returns:
  146. tuple[Tensor]: A tuple of CentripetalHead's output for current
  147. feature level. Containing the following Tensors:
  148. - tl_heat (Tensor): Predicted top-left corner heatmap.
  149. - br_heat (Tensor): Predicted bottom-right corner heatmap.
  150. - tl_off (Tensor): Predicted top-left offset heatmap.
  151. - br_off (Tensor): Predicted bottom-right offset heatmap.
  152. - tl_guiding_shift (Tensor): Predicted top-left guiding shift
  153. heatmap.
  154. - br_guiding_shift (Tensor): Predicted bottom-right guiding
  155. shift heatmap.
  156. - tl_centripetal_shift (Tensor): Predicted top-left centripetal
  157. shift heatmap.
  158. - br_centripetal_shift (Tensor): Predicted bottom-right
  159. centripetal shift heatmap.
  160. """
  161. tl_heat, br_heat, _, _, tl_off, br_off, tl_pool, br_pool = super(
  162. ).forward_single(
  163. x, lvl_ind, return_pool=True)
  164. tl_guiding_shift = self.tl_guiding_shift[lvl_ind](tl_pool)
  165. br_guiding_shift = self.br_guiding_shift[lvl_ind](br_pool)
  166. tl_dcn_offset = self.tl_dcn_offset[lvl_ind](tl_guiding_shift.detach())
  167. br_dcn_offset = self.br_dcn_offset[lvl_ind](br_guiding_shift.detach())
  168. tl_feat_adaption = self.tl_feat_adaption[lvl_ind](tl_pool,
  169. tl_dcn_offset)
  170. br_feat_adaption = self.br_feat_adaption[lvl_ind](br_pool,
  171. br_dcn_offset)
  172. tl_centripetal_shift = self.tl_centripetal_shift[lvl_ind](
  173. tl_feat_adaption)
  174. br_centripetal_shift = self.br_centripetal_shift[lvl_ind](
  175. br_feat_adaption)
  176. result_list = [
  177. tl_heat, br_heat, tl_off, br_off, tl_guiding_shift,
  178. br_guiding_shift, tl_centripetal_shift, br_centripetal_shift
  179. ]
  180. return result_list
  181. def loss(self,
  182. tl_heats,
  183. br_heats,
  184. tl_offs,
  185. br_offs,
  186. tl_guiding_shifts,
  187. br_guiding_shifts,
  188. tl_centripetal_shifts,
  189. br_centripetal_shifts,
  190. gt_bboxes,
  191. gt_labels,
  192. img_metas,
  193. gt_bboxes_ignore=None):
  194. """Compute losses of the head.
  195. Args:
  196. tl_heats (list[Tensor]): Top-left corner heatmaps for each level
  197. with shape (N, num_classes, H, W).
  198. br_heats (list[Tensor]): Bottom-right corner heatmaps for each
  199. level with shape (N, num_classes, H, W).
  200. tl_offs (list[Tensor]): Top-left corner offsets for each level
  201. with shape (N, corner_offset_channels, H, W).
  202. br_offs (list[Tensor]): Bottom-right corner offsets for each level
  203. with shape (N, corner_offset_channels, H, W).
  204. tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
  205. level with shape (N, guiding_shift_channels, H, W).
  206. br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
  207. each level with shape (N, guiding_shift_channels, H, W).
  208. tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
  209. for each level with shape (N, centripetal_shift_channels, H,
  210. W).
  211. br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
  212. shifts for each level with shape (N,
  213. centripetal_shift_channels, H, W).
  214. gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
  215. shape (num_gts, 4) in [left, top, right, bottom] format.
  216. gt_labels (list[Tensor]): Class indices corresponding to each box.
  217. img_metas (list[dict]): Meta information of each image, e.g.,
  218. image size, scaling factor, etc.
  219. gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
  220. boxes can be ignored when computing the loss.
  221. Returns:
  222. dict[str, Tensor]: A dictionary of loss components. Containing the
  223. following losses:
  224. - det_loss (list[Tensor]): Corner keypoint losses of all
  225. feature levels.
  226. - off_loss (list[Tensor]): Corner offset losses of all feature
  227. levels.
  228. - guiding_loss (list[Tensor]): Guiding shift losses of all
  229. feature levels.
  230. - centripetal_loss (list[Tensor]): Centripetal shift losses of
  231. all feature levels.
  232. """
  233. targets = self.get_targets(
  234. gt_bboxes,
  235. gt_labels,
  236. tl_heats[-1].shape,
  237. img_metas[0]['pad_shape'],
  238. with_corner_emb=self.with_corner_emb,
  239. with_guiding_shift=True,
  240. with_centripetal_shift=True)
  241. mlvl_targets = [targets for _ in range(self.num_feat_levels)]
  242. [det_losses, off_losses, guiding_losses, centripetal_losses
  243. ] = multi_apply(self.loss_single, tl_heats, br_heats, tl_offs,
  244. br_offs, tl_guiding_shifts, br_guiding_shifts,
  245. tl_centripetal_shifts, br_centripetal_shifts,
  246. mlvl_targets)
  247. loss_dict = dict(
  248. det_loss=det_losses,
  249. off_loss=off_losses,
  250. guiding_loss=guiding_losses,
  251. centripetal_loss=centripetal_losses)
  252. return loss_dict
  253. def loss_single(self, tl_hmp, br_hmp, tl_off, br_off, tl_guiding_shift,
  254. br_guiding_shift, tl_centripetal_shift,
  255. br_centripetal_shift, targets):
  256. """Compute losses for single level.
  257. Args:
  258. tl_hmp (Tensor): Top-left corner heatmap for current level with
  259. shape (N, num_classes, H, W).
  260. br_hmp (Tensor): Bottom-right corner heatmap for current level with
  261. shape (N, num_classes, H, W).
  262. tl_off (Tensor): Top-left corner offset for current level with
  263. shape (N, corner_offset_channels, H, W).
  264. br_off (Tensor): Bottom-right corner offset for current level with
  265. shape (N, corner_offset_channels, H, W).
  266. tl_guiding_shift (Tensor): Top-left guiding shift for current level
  267. with shape (N, guiding_shift_channels, H, W).
  268. br_guiding_shift (Tensor): Bottom-right guiding shift for current
  269. level with shape (N, guiding_shift_channels, H, W).
  270. tl_centripetal_shift (Tensor): Top-left centripetal shift for
  271. current level with shape (N, centripetal_shift_channels, H, W).
  272. br_centripetal_shift (Tensor): Bottom-right centripetal shift for
  273. current level with shape (N, centripetal_shift_channels, H, W).
  274. targets (dict): Corner target generated by `get_targets`.
  275. Returns:
  276. tuple[torch.Tensor]: Losses of the head's different branches
  277. containing the following losses:
  278. - det_loss (Tensor): Corner keypoint loss.
  279. - off_loss (Tensor): Corner offset loss.
  280. - guiding_loss (Tensor): Guiding shift loss.
  281. - centripetal_loss (Tensor): Centripetal shift loss.
  282. """
  283. targets['corner_embedding'] = None
  284. det_loss, _, _, off_loss = super().loss_single(tl_hmp, br_hmp, None,
  285. None, tl_off, br_off,
  286. targets)
  287. gt_tl_guiding_shift = targets['topleft_guiding_shift']
  288. gt_br_guiding_shift = targets['bottomright_guiding_shift']
  289. gt_tl_centripetal_shift = targets['topleft_centripetal_shift']
  290. gt_br_centripetal_shift = targets['bottomright_centripetal_shift']
  291. gt_tl_heatmap = targets['topleft_heatmap']
  292. gt_br_heatmap = targets['bottomright_heatmap']
  293. # We only compute the offset loss at the real corner position.
  294. # The value of real corner would be 1 in heatmap ground truth.
  295. # The mask is computed in class agnostic mode and its shape is
  296. # batch * 1 * width * height.
  297. tl_mask = gt_tl_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
  298. gt_tl_heatmap)
  299. br_mask = gt_br_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
  300. gt_br_heatmap)
  301. # Guiding shift loss
  302. tl_guiding_loss = self.loss_guiding_shift(
  303. tl_guiding_shift,
  304. gt_tl_guiding_shift,
  305. tl_mask,
  306. avg_factor=tl_mask.sum())
  307. br_guiding_loss = self.loss_guiding_shift(
  308. br_guiding_shift,
  309. gt_br_guiding_shift,
  310. br_mask,
  311. avg_factor=br_mask.sum())
  312. guiding_loss = (tl_guiding_loss + br_guiding_loss) / 2.0
  313. # Centripetal shift loss
  314. tl_centripetal_loss = self.loss_centripetal_shift(
  315. tl_centripetal_shift,
  316. gt_tl_centripetal_shift,
  317. tl_mask,
  318. avg_factor=tl_mask.sum())
  319. br_centripetal_loss = self.loss_centripetal_shift(
  320. br_centripetal_shift,
  321. gt_br_centripetal_shift,
  322. br_mask,
  323. avg_factor=br_mask.sum())
  324. centripetal_loss = (tl_centripetal_loss + br_centripetal_loss) / 2.0
  325. return det_loss, off_loss, guiding_loss, centripetal_loss
  326. def get_bboxes(self,
  327. tl_heats,
  328. br_heats,
  329. tl_offs,
  330. br_offs,
  331. tl_guiding_shifts,
  332. br_guiding_shifts,
  333. tl_centripetal_shifts,
  334. br_centripetal_shifts,
  335. img_metas,
  336. rescale=False,
  337. with_nms=True):
  338. """Transform network output for a batch into bbox predictions.
  339. Args:
  340. tl_heats (list[Tensor]): Top-left corner heatmaps for each level
  341. with shape (N, num_classes, H, W).
  342. br_heats (list[Tensor]): Bottom-right corner heatmaps for each
  343. level with shape (N, num_classes, H, W).
  344. tl_offs (list[Tensor]): Top-left corner offsets for each level
  345. with shape (N, corner_offset_channels, H, W).
  346. br_offs (list[Tensor]): Bottom-right corner offsets for each level
  347. with shape (N, corner_offset_channels, H, W).
  348. tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
  349. level with shape (N, guiding_shift_channels, H, W). Useless in
  350. this function, we keep this arg because it's the raw output
  351. from CentripetalHead.
  352. br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
  353. each level with shape (N, guiding_shift_channels, H, W).
  354. Useless in this function, we keep this arg because it's the
  355. raw output from CentripetalHead.
  356. tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
  357. for each level with shape (N, centripetal_shift_channels, H,
  358. W).
  359. br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
  360. shifts for each level with shape (N,
  361. centripetal_shift_channels, H, W).
  362. img_metas (list[dict]): Meta information of each image, e.g.,
  363. image size, scaling factor, etc.
  364. rescale (bool): If True, return boxes in original image space.
  365. Default: False.
  366. with_nms (bool): If True, do nms before return boxes.
  367. Default: True.
  368. """
  369. assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(img_metas)
  370. result_list = []
  371. for img_id in range(len(img_metas)):
  372. result_list.append(
  373. self._get_bboxes_single(
  374. tl_heats[-1][img_id:img_id + 1, :],
  375. br_heats[-1][img_id:img_id + 1, :],
  376. tl_offs[-1][img_id:img_id + 1, :],
  377. br_offs[-1][img_id:img_id + 1, :],
  378. img_metas[img_id],
  379. tl_emb=None,
  380. br_emb=None,
  381. tl_centripetal_shift=tl_centripetal_shifts[-1][
  382. img_id:img_id + 1, :],
  383. br_centripetal_shift=br_centripetal_shifts[-1][
  384. img_id:img_id + 1, :],
  385. rescale=rescale,
  386. with_nms=with_nms))
  387. return result_list

No Description

Contributors (3)