centernet_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import bias_init_with_prob, normal_init
from mmcv.ops import batched_nms
from mmcv.runner import force_fp32

from mmdet.core import multi_apply
from mmdet.models import HEADS, build_loss
from mmdet.models.utils import gaussian_radius, gen_gaussian_target

from ..utils.gaussian_target import (get_local_maximum, get_topk_from_heatmap,
                                     transpose_and_gather_feat)
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()
class CenterNetHead(BaseDenseHead, BBoxTestMixin):
  16. """Objects as Points Head. CenterHead use center_point to indicate object's
  17. position. Paper link <https://arxiv.org/abs/1904.07850>
  18. Args:
  19. in_channel (int): Number of channel in the input feature map.
  20. feat_channel (int): Number of channel in the intermediate feature map.
  21. num_classes (int): Number of categories excluding the background
  22. category.
  23. loss_center_heatmap (dict | None): Config of center heatmap loss.
  24. Default: GaussianFocalLoss.
  25. loss_wh (dict | None): Config of wh loss. Default: L1Loss.
  26. loss_offset (dict | None): Config of offset loss. Default: L1Loss.
  27. train_cfg (dict | None): Training config. Useless in CenterNet,
  28. but we keep this variable for SingleStageDetector. Default: None.
  29. test_cfg (dict | None): Testing config of CenterNet. Default: None.
  30. init_cfg (dict or list[dict], optional): Initialization config dict.
  31. Default: None
  32. """

    def __init__(self,
                 in_channel,
                 feat_channel,
                 num_classes,
                 loss_center_heatmap=dict(
                     type='GaussianFocalLoss', loss_weight=1.0),
                 loss_wh=dict(type='L1Loss', loss_weight=0.1),
                 loss_offset=dict(type='L1Loss', loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        super(CenterNetHead, self).__init__(init_cfg)
        self.num_classes = num_classes
        self.heatmap_head = self._build_head(in_channel, feat_channel,
                                             num_classes)
        self.wh_head = self._build_head(in_channel, feat_channel, 2)
        self.offset_head = self._build_head(in_channel, feat_channel, 2)
        self.loss_center_heatmap = build_loss(loss_center_heatmap)
        self.loss_wh = build_loss(loss_wh)
        self.loss_offset = build_loss(loss_offset)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False

    def _build_head(self, in_channel, feat_channel, out_channel):
        """Build head for each branch."""
        layer = nn.Sequential(
            nn.Conv2d(in_channel, feat_channel, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(feat_channel, out_channel, kernel_size=1))
        return layer

    def init_weights(self):
        """Initialize weights of the head."""
        bias_init = bias_init_with_prob(0.1)
        self.heatmap_head[-1].bias.data.fill_(bias_init)
        for head in [self.wh_head, self.offset_head]:
            for m in head.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)

    def forward(self, feats):
        """Forward features. Notice CenterNet head does not use FPN.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D tensor.

        Returns:
            center_heatmap_preds (list[Tensor]): Center heatmap predictions
                for all levels, each with num_classes channels.
            wh_preds (list[Tensor]): wh predictions for all levels, each with
                2 channels.
            offset_preds (list[Tensor]): Offset predictions for all levels,
                each with 2 channels.
        """
        return multi_apply(self.forward_single, feats)

    def forward_single(self, feat):
        """Forward feature of a single level.

        Args:
            feat (Tensor): Feature of a single level.

        Returns:
            center_heatmap_pred (Tensor): Center heatmap prediction with
                num_classes channels.
            wh_pred (Tensor): wh prediction with 2 channels.
            offset_pred (Tensor): Offset prediction with 2 channels.
        """
        center_heatmap_pred = self.heatmap_head(feat).sigmoid()
        wh_pred = self.wh_head(feat)
        offset_pred = self.offset_head(feat)
        return center_heatmap_pred, wh_pred, offset_pred

    @force_fp32(apply_to=('center_heatmap_preds', 'wh_preds', 'offset_preds'))
    def loss(self,
             center_heatmap_preds,
             wh_preds,
             offset_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            center_heatmap_preds (list[Tensor]): Center heatmap predictions
                for all levels with shape (B, num_classes, H, W).
            wh_preds (list[Tensor]): wh predictions for all levels with
                shape (B, 2, H, W).
            offset_preds (list[Tensor]): Offset predictions for all levels
                with shape (B, 2, H, W).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specifies which bounding
                boxes can be ignored when computing the loss. Default: None.

        Returns:
            dict[str, Tensor]: which has components below:
                - loss_center_heatmap (Tensor): loss of center heatmap.
                - loss_wh (Tensor): loss of wh heatmap.
                - loss_offset (Tensor): loss of offset heatmap.
        """
        assert len(center_heatmap_preds) == len(wh_preds) == len(
            offset_preds) == 1
        center_heatmap_pred = center_heatmap_preds[0]
        wh_pred = wh_preds[0]
        offset_pred = offset_preds[0]

        target_result, avg_factor = self.get_targets(gt_bboxes, gt_labels,
                                                     center_heatmap_pred.shape,
                                                     img_metas[0]['pad_shape'])

        center_heatmap_target = target_result['center_heatmap_target']
        wh_target = target_result['wh_target']
        offset_target = target_result['offset_target']
        wh_offset_target_weight = target_result['wh_offset_target_weight']

        # Since the channel of wh_target and offset_target is 2, the
        # avg_factor of loss_center_heatmap is always 1/2 of loss_wh and
        # loss_offset.
        loss_center_heatmap = self.loss_center_heatmap(
            center_heatmap_pred, center_heatmap_target, avg_factor=avg_factor)
        loss_wh = self.loss_wh(
            wh_pred,
            wh_target,
            wh_offset_target_weight,
            avg_factor=avg_factor * 2)
        loss_offset = self.loss_offset(
            offset_pred,
            offset_target,
            wh_offset_target_weight,
            avg_factor=avg_factor * 2)
        return dict(
            loss_center_heatmap=loss_center_heatmap,
            loss_wh=loss_wh,
            loss_offset=loss_offset)

    def get_targets(self, gt_bboxes, gt_labels, feat_shape, img_shape):
        """Compute regression and classification targets in multiple images.

        Args:
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            feat_shape (list[int]): Feature map shape with value [B, _, H, W].
            img_shape (list[int]): Image shape in [h, w] format.

        Returns:
            tuple[dict, float]: The float value is the avg_factor; the dict
            has components below:
                - center_heatmap_target (Tensor): targets of center heatmap, \
                  shape (B, num_classes, H, W).
                - wh_target (Tensor): targets of wh predict, shape \
                  (B, 2, H, W).
                - offset_target (Tensor): targets of offset predict, shape \
                  (B, 2, H, W).
                - wh_offset_target_weight (Tensor): weights of wh and offset \
                  predict, shape (B, 2, H, W).
        """
        img_h, img_w = img_shape[:2]
        bs, _, feat_h, feat_w = feat_shape

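        # These ratios map box coordinates from the padded input image to the
        # output feature map (the default CenterNet output stride is 4, i.e.
        # both ratios are typically 0.25).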
        width_ratio = float(feat_w / img_w)
        height_ratio = float(feat_h / img_h)

        center_heatmap_target = gt_bboxes[-1].new_zeros(
            [bs, self.num_classes, feat_h, feat_w])
        wh_target = gt_bboxes[-1].new_zeros([bs, 2, feat_h, feat_w])
        offset_target = gt_bboxes[-1].new_zeros([bs, 2, feat_h, feat_w])
        wh_offset_target_weight = gt_bboxes[-1].new_zeros(
            [bs, 2, feat_h, feat_w])

        for batch_id in range(bs):
            gt_bbox = gt_bboxes[batch_id]
            gt_label = gt_labels[batch_id]
            center_x = (gt_bbox[:, [0]] + gt_bbox[:, [2]]) * width_ratio / 2
            center_y = (gt_bbox[:, [1]] + gt_bbox[:, [3]]) * height_ratio / 2
            gt_centers = torch.cat((center_x, center_y), dim=1)

            for j, ct in enumerate(gt_centers):
                ctx_int, cty_int = ct.int()
                ctx, cty = ct
                scale_box_h = (gt_bbox[j][3] - gt_bbox[j][1]) * height_ratio
                scale_box_w = (gt_bbox[j][2] - gt_bbox[j][0]) * width_ratio
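                # The Gaussian radius is chosen so that a center placed
                # anywhere within it still yields a box with IoU >= 0.3
                # against the ground truth (CornerNet's radius heuristic).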
                radius = gaussian_radius([scale_box_h, scale_box_w],
                                         min_overlap=0.3)
                radius = max(0, int(radius))
                ind = gt_label[j]
                gen_gaussian_target(center_heatmap_target[batch_id, ind],
                                    [ctx_int, cty_int], radius)

                wh_target[batch_id, 0, cty_int, ctx_int] = scale_box_w
                wh_target[batch_id, 1, cty_int, ctx_int] = scale_box_h

                offset_target[batch_id, 0, cty_int, ctx_int] = ctx - ctx_int
                offset_target[batch_id, 1, cty_int, ctx_int] = cty - cty_int

                wh_offset_target_weight[batch_id, :, cty_int, ctx_int] = 1
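
        # avg_factor counts heatmap peaks (pixels exactly equal to 1),
        # i.e. roughly one per ground-truth box.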
        avg_factor = max(1, center_heatmap_target.eq(1).sum())
        target_result = dict(
            center_heatmap_target=center_heatmap_target,
            wh_target=wh_target,
            offset_target=offset_target,
            wh_offset_target_weight=wh_offset_target_weight)
        return target_result, avg_factor

    @force_fp32(apply_to=('center_heatmap_preds', 'wh_preds', 'offset_preds'))
    def get_bboxes(self,
                   center_heatmap_preds,
                   wh_preds,
                   offset_preds,
                   img_metas,
                   rescale=True,
                   with_nms=False):
        """Transform network output for a batch into bbox predictions.

        Args:
            center_heatmap_preds (list[Tensor]): Center heatmap predictions
                for all levels with shape (B, num_classes, H, W).
            wh_preds (list[Tensor]): wh predictions for all levels with
                shape (B, 2, H, W).
            offset_preds (list[Tensor]): Offset predictions for all levels
                with shape (B, 2, H, W).
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.
                Default: True.
            with_nms (bool): If True, do nms before returning boxes.
                Default: False.

        Returns:
            list[tuple[Tensor, Tensor]]: Each item in result_list is a
            2-tuple. The first item is an (n, 5) tensor, where the 5 columns
            are (tl_x, tl_y, br_x, br_y, score) and the score is between 0
            and 1. The second tensor has shape (n,), and each element is the
            class label of the corresponding box.
        """
        assert len(center_heatmap_preds) == len(wh_preds) == len(
            offset_preds) == 1
        result_list = []
        for img_id in range(len(img_metas)):
            result_list.append(
                self._get_bboxes_single(
                    center_heatmap_preds[0][img_id:img_id + 1, ...],
                    wh_preds[0][img_id:img_id + 1, ...],
                    offset_preds[0][img_id:img_id + 1, ...],
                    img_metas[img_id],
                    rescale=rescale,
                    with_nms=with_nms))
        return result_list

    def _get_bboxes_single(self,
                           center_heatmap_pred,
                           wh_pred,
                           offset_pred,
                           img_meta,
                           rescale=False,
                           with_nms=True):
        """Transform outputs of a single image into bbox results.

        Args:
            center_heatmap_pred (Tensor): Center heatmap for current level
                with shape (1, num_classes, H, W).
            wh_pred (Tensor): wh prediction for current level with shape
                (1, 2, H, W).
            offset_pred (Tensor): Offset prediction for current level with
                shape (1, 2, H, W).
            img_meta (dict): Meta information of current image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before returning boxes.
                Default: True.

        Returns:
            tuple[Tensor, Tensor]: The first item is an (n, 5) tensor, where
            the 5 columns are (tl_x, tl_y, br_x, br_y, score) and the score
            is between 0 and 1. The second tensor has shape (n,), and each
            element is the class label of the corresponding box.
        """
        batch_det_bboxes, batch_labels = self.decode_heatmap(
            center_heatmap_pred,
            wh_pred,
            offset_pred,
            img_meta['batch_input_shape'],
            k=self.test_cfg.topk,
            kernel=self.test_cfg.local_maximum_kernel)

        det_bboxes = batch_det_bboxes.view([-1, 5])
        det_labels = batch_labels.view(-1)
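
        # 'border' is written into img_meta by CenterNet's test pipeline
        # (e.g. RandomCenterCropPad); subtracting it maps boxes from the
        # padded input frame back to the original image frame.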
        batch_border = det_bboxes.new_tensor(img_meta['border'])[...,
                                                                 [2, 0, 2, 0]]
        det_bboxes[..., :4] -= batch_border

        if rescale:
            det_bboxes[..., :4] /= det_bboxes.new_tensor(
                img_meta['scale_factor'])

        if with_nms:
            det_bboxes, det_labels = self._bboxes_nms(det_bboxes, det_labels,
                                                      self.test_cfg)
        return det_bboxes, det_labels

    def decode_heatmap(self,
                       center_heatmap_pred,
                       wh_pred,
                       offset_pred,
                       img_shape,
                       k=100,
                       kernel=3):
        """Transform outputs into raw bbox predictions.

        Args:
            center_heatmap_pred (Tensor): Center heatmap prediction,
                shape (B, num_classes, H, W).
            wh_pred (Tensor): wh prediction, shape (B, 2, H, W).
            offset_pred (Tensor): Offset prediction, shape (B, 2, H, W).
            img_shape (list[int]): Image shape in [h, w] format.
            k (int): Get top k center keypoints from heatmap. Default 100.
            kernel (int): Max pooling kernel used to extract local maximum
                pixels. Default 3.

        Returns:
            tuple[torch.Tensor]: Decoded output of CenterNetHead, containing
            the following Tensors:
                - batch_bboxes (Tensor): Coords of each box with shape \
                  (B, k, 5).
                - batch_topk_labels (Tensor): Categories of each box with \
                  shape (B, k).
        """
        height, width = center_heatmap_pred.shape[2:]
        inp_h, inp_w = img_shape

        center_heatmap_pred = get_local_maximum(
            center_heatmap_pred, kernel=kernel)

        *batch_dets, topk_ys, topk_xs = get_topk_from_heatmap(
            center_heatmap_pred, k=k)
        batch_scores, batch_index, batch_topk_labels = batch_dets

        wh = transpose_and_gather_feat(wh_pred, batch_index)
        offset = transpose_and_gather_feat(offset_pred, batch_index)
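
        # Refine the integer peak coordinates with the predicted sub-pixel
        # offsets, then rescale centers and sizes from feature-map
        # coordinates back to input-image coordinates.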
        topk_xs = topk_xs + offset[..., 0]
        topk_ys = topk_ys + offset[..., 1]
        tl_x = (topk_xs - wh[..., 0] / 2) * (inp_w / width)
        tl_y = (topk_ys - wh[..., 1] / 2) * (inp_h / height)
        br_x = (topk_xs + wh[..., 0] / 2) * (inp_w / width)
        br_y = (topk_ys + wh[..., 1] / 2) * (inp_h / height)

        batch_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=2)
        batch_bboxes = torch.cat((batch_bboxes, batch_scores[..., None]),
                                 dim=-1)
        return batch_bboxes, batch_topk_labels

    def _bboxes_nms(self, bboxes, labels, cfg):
        if labels.numel() > 0:
            max_num = cfg.max_per_img
            bboxes, keep = batched_nms(bboxes[:, :4],
                                       bboxes[:, -1].contiguous(),
                                       labels, cfg.nms)
            if max_num > 0:
                bboxes = bboxes[:max_num]
                labels = labels[keep][:max_num]
        return bboxes, labels
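

# ----------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the upstream file), assuming
# mmdet 2.x and mmcv-full are installed so the imports above resolve. It
# builds the head with hypothetical channel/class counts, runs a forward
# pass on a random single-level feature map, and checks the output shapes
# documented in forward().
if __name__ == '__main__':
    head = CenterNetHead(in_channel=16, feat_channel=16, num_classes=4)
    head.init_weights()
    feat = torch.rand(2, 16, 32, 32)  # (B, in_channel, H, W)
    center_heatmap_preds, wh_preds, offset_preds = head([feat])
    assert center_heatmap_preds[0].shape == (2, 4, 32, 32)
    assert wh_preds[0].shape == (2, 2, 32, 32)
    assert offset_preds[0].shape == (2, 2, 32, 32)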
