# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
                      bias_init_with_prob)
from mmcv.ops.nms import batched_nms
from mmcv.runner import force_fp32

from mmdet.core import (MlvlPointGenerator, bbox_xyxy_to_cxcywh,
                        build_assigner, build_sampler, multi_apply)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()
class YOLOXHead(BaseDenseHead, BBoxTestMixin):
    """YOLOX head used in `YOLOX <https://arxiv.org/abs/2107.08430>`_.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels in stacking convs.
            Default: 256.
        stacked_convs (int): Number of stacking convs of the head.
            Default: 2.
        strides (tuple): Downsample factor of each feature map.
        use_depthwise (bool): Whether to use depthwise separable convolution
            in blocks. Default: False.
        dcn_on_last_conv (bool): If true, use DCN in the last layer of
            towers. Default: False.
        conv_bias (bool | str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Default: "auto".
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish').
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        loss_obj (dict): Config of objectness loss.
        loss_l1 (dict): Config of L1 loss.
        train_cfg (dict): Training config of anchor head.
        test_cfg (dict): Testing config of anchor head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=2,
                 strides=[8, 16, 32],
                 use_depthwise=False,
                 dcn_on_last_conv=False,
                 conv_bias='auto',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
                 act_cfg=dict(type='Swish'),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='sum',
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='IoULoss',
                     mode='square',
                     eps=1e-16,
                     reduction='sum',
                     loss_weight=5.0),
                 loss_obj=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='sum',
                     loss_weight=1.0),
                 loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Kaiming',
                     layer='Conv2d',
                     a=math.sqrt(5),
                     distribution='uniform',
                     mode='fan_in',
                     nonlinearity='leaky_relu')):

        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.cls_out_channels = num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.use_depthwise = use_depthwise
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.use_sigmoid_cls = True

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.loss_obj = build_loss(loss_obj)

        self.use_l1 = False  # This flag will be modified by hooks.
        self.loss_l1 = build_loss(loss_l1)

        self.prior_generator = MlvlPointGenerator(strides, offset=0)

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg

        self.sampling = False
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # sampling=False so use PseudoSampler
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        self.fp16_enabled = False
        self._init_layers()
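
    # NOTE: `train_cfg` and `test_cfg` are supplied by the detector config,
    # not by this file. As an illustration only, the reference YOLOX configs
    # typically use something like:
    #   train_cfg = dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5))
    #   test_cfg = dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))
    # `self.assigner` above and the `cfg.score_thr` / `cfg.nms` fields read in
    # `_bboxes_nms` below assume configs of roughly this shape.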

    def _init_layers(self):
        self.multi_level_cls_convs = nn.ModuleList()
        self.multi_level_reg_convs = nn.ModuleList()
        self.multi_level_conv_cls = nn.ModuleList()
        self.multi_level_conv_reg = nn.ModuleList()
        self.multi_level_conv_obj = nn.ModuleList()
        for _ in self.strides:
            self.multi_level_cls_convs.append(self._build_stacked_convs())
            self.multi_level_reg_convs.append(self._build_stacked_convs())
            conv_cls, conv_reg, conv_obj = self._build_predictor()
            self.multi_level_conv_cls.append(conv_cls)
            self.multi_level_conv_reg.append(conv_reg)
            self.multi_level_conv_obj.append(conv_obj)

    def _build_stacked_convs(self):
        """Initialize conv layers of a single level head."""
        conv = DepthwiseSeparableConvModule \
            if self.use_depthwise else ConvModule
        stacked_convs = []
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            stacked_convs.append(
                conv(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=self.conv_bias))
        return nn.Sequential(*stacked_convs)

    def _build_predictor(self):
        """Initialize predictor layers of a single level head."""
        conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1)
        conv_reg = nn.Conv2d(self.feat_channels, 4, 1)
        conv_obj = nn.Conv2d(self.feat_channels, 1, 1)
        return conv_cls, conv_reg, conv_obj

    def init_weights(self):
        super(YOLOXHead, self).init_weights()
        # Use prior in model initialization to improve stability
        bias_init = bias_init_with_prob(0.01)
        for conv_cls, conv_obj in zip(self.multi_level_conv_cls,
                                      self.multi_level_conv_obj):
            conv_cls.bias.data.fill_(bias_init)
            conv_obj.bias.data.fill_(bias_init)
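
    # `bias_init_with_prob(0.01)` from mmcv computes -log((1 - p) / p), so the
    # classification and objectness logits start near sigmoid^-1(0.01): the
    # head predicts roughly 1% foreground probability everywhere at the first
    # iteration, which keeps the sigmoid BCE losses stable early in training.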

    def forward_single(self, x, cls_convs, reg_convs, conv_cls, conv_reg,
                       conv_obj):
        """Forward feature of a single scale level."""
        cls_feat = cls_convs(x)
        reg_feat = reg_convs(x)

        cls_score = conv_cls(cls_feat)
        bbox_pred = conv_reg(reg_feat)
        objectness = conv_obj(reg_feat)

        return cls_score, bbox_pred, objectness

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            tuple[Tensor]: A tuple of multi-level prediction maps, each is a
                4D-tensor of shape (batch_size, 5+num_classes, height, width).
        """
        return multi_apply(self.forward_single, feats,
                           self.multi_level_cls_convs,
                           self.multi_level_reg_convs,
                           self.multi_level_conv_cls,
                           self.multi_level_conv_reg,
                           self.multi_level_conv_obj)
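
    # Shape sketch (illustrative only; assumes in_channels=256, num_classes=80
    # and a 640x640 input, so the levels at strides (8, 16, 32) are 80x80,
    # 40x40 and 20x20). forward() then returns three tuples of per-level maps:
    #   cls_scores:   (B, 80, 80, 80), (B, 80, 40, 40), (B, 80, 20, 20)
    #   bbox_preds:   (B, 4,  80, 80), (B, 4,  40, 40), (B, 4,  20, 20)
    #   objectnesses: (B, 1,  80, 80), (B, 1,  40, 40), (B, 1,  20, 20)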

    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   objectnesses,
                   img_metas=None,
                   cfg=None,
                   rescale=False,
                   with_nms=True):
        """Transform network outputs of a batch into bbox results.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            objectnesses (list[Tensor], Optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            img_metas (list[dict], Optional): Image meta info. Default None.
            cfg (mmcv.Config, Optional): Test / postprocessing configuration,
                if None, test_cfg would be used. Default None.
            rescale (bool): If True, return boxes in original image space.
                Default False.
            with_nms (bool): If True, do nms before returning boxes.
                Default True.
        Returns:
            list[list[Tensor, Tensor]]: Each item in result_list is a 2-tuple.
                The first item is an (n, 5) tensor, where the first 4 columns
                are bounding box positions (tl_x, tl_y, br_x, br_y) and the
                5-th column is a score between 0 and 1. The second item is a
                (n,) tensor where each item is the predicted class label of
                the corresponding box.
        """
        assert len(cls_scores) == len(bbox_preds) == len(objectnesses)
        cfg = self.test_cfg if cfg is None else cfg
        scale_factors = [img_meta['scale_factor'] for img_meta in img_metas]

        num_imgs = len(img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device,
            with_stride=True)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_objectness = [
            objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
            for objectness in objectnesses
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        flatten_priors = torch.cat(mlvl_priors)

        flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds)

        if rescale:
            flatten_bboxes[..., :4] /= flatten_bboxes.new_tensor(
                scale_factors).unsqueeze(1)

        result_list = []
        for img_id in range(len(img_metas)):
            cls_scores = flatten_cls_scores[img_id]
            score_factor = flatten_objectness[img_id]
            bboxes = flatten_bboxes[img_id]

            result_list.append(
                self._bboxes_nms(cls_scores, bboxes, score_factor, cfg))

        return result_list

    def _bbox_decode(self, priors, bbox_preds):
        xys = (bbox_preds[..., :2] * priors[:, 2:]) + priors[:, :2]
        whs = bbox_preds[..., 2:].exp() * priors[:, 2:]

        tl_x = (xys[..., 0] - whs[..., 0] / 2)
        tl_y = (xys[..., 1] - whs[..., 1] / 2)
        br_x = (xys[..., 0] + whs[..., 0] / 2)
        br_y = (xys[..., 1] + whs[..., 1] / 2)

        decoded_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
        return decoded_bboxes
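
    # Decoding summary (restating `_bbox_decode` above): given a prior
    # (grid_x, grid_y, stride_w, stride_h) and a raw prediction (dx, dy, dw, dh),
    #   cx = grid_x + dx * stride_w,  cy = grid_y + dy * stride_h
    #   w  = exp(dw) * stride_w,      h  = exp(dh) * stride_h
    # and the box is returned as (cx - w/2, cy - h/2, cx + w/2, cy + h/2).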

    def _bboxes_nms(self, cls_scores, bboxes, score_factor, cfg):
        max_scores, labels = torch.max(cls_scores, 1)
        valid_mask = score_factor * max_scores >= cfg.score_thr

        bboxes = bboxes[valid_mask]
        scores = max_scores[valid_mask] * score_factor[valid_mask]
        labels = labels[valid_mask]

        if labels.numel() == 0:
            return bboxes, labels
        else:
            dets, keep = batched_nms(bboxes, scores, labels, cfg.nms)
            return dets, labels[keep]

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'objectnesses'))
    def loss(self,
             cls_scores,
             bbox_preds,
             objectnesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (list[Tensor], Optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.
        """
        num_imgs = len(img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device,
            with_stride=True)

        flatten_cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.cls_out_channels)
            for cls_pred in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_objectness = [
            objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
            for objectness in objectnesses
        ]

        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_objectness = torch.cat(flatten_objectness, dim=1)
        flatten_priors = torch.cat(mlvl_priors)
        flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds)

        (pos_masks, cls_targets, obj_targets, bbox_targets, l1_targets,
         num_fg_imgs) = multi_apply(
             self._get_target_single, flatten_cls_preds.detach(),
             flatten_objectness.detach(),
             flatten_priors.unsqueeze(0).repeat(num_imgs, 1, 1),
             flatten_bboxes.detach(), gt_bboxes, gt_labels)

        num_total_samples = max(sum(num_fg_imgs), 1)

        pos_masks = torch.cat(pos_masks, 0)
        cls_targets = torch.cat(cls_targets, 0)
        obj_targets = torch.cat(obj_targets, 0)
        bbox_targets = torch.cat(bbox_targets, 0)
        if self.use_l1:
            l1_targets = torch.cat(l1_targets, 0)

        loss_bbox = self.loss_bbox(
            flatten_bboxes.view(-1, 4)[pos_masks],
            bbox_targets) / num_total_samples
        loss_obj = self.loss_obj(flatten_objectness.view(-1, 1),
                                 obj_targets) / num_total_samples
        loss_cls = self.loss_cls(
            flatten_cls_preds.view(-1, self.num_classes)[pos_masks],
            cls_targets) / num_total_samples

        loss_dict = dict(
            loss_cls=loss_cls, loss_bbox=loss_bbox, loss_obj=loss_obj)

        if self.use_l1:
            loss_l1 = self.loss_l1(
                flatten_bbox_preds.view(-1, 4)[pos_masks],
                l1_targets) / num_total_samples
            loss_dict.update(loss_l1=loss_l1)

        return loss_dict
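
    # Note on normalization: every loss above is built with reduction='sum'
    # (see the defaults in __init__) and is divided by `num_total_samples`,
    # the number of positive priors across the batch, clamped to at least 1.
    # `use_l1` stays False by default; in the reference YOLOX training
    # schedule a hook flips it on for the final epochs, adding an L1 loss on
    # the raw (dx, dy, dw, dh) predictions to fine-tune box regression.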

    @torch.no_grad()
    def _get_target_single(self, cls_preds, objectness, priors, decoded_bboxes,
                           gt_bboxes, gt_labels):
        """Compute classification, regression, and objectness targets for
        priors in a single image.

        Args:
            cls_preds (Tensor): Classification predictions of one image,
                a 2D-Tensor with shape [num_priors, num_classes].
            objectness (Tensor): Objectness predictions of one image,
                a 1D-Tensor with shape [num_priors].
            priors (Tensor): All priors of one image, a 2D-Tensor with shape
                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
            decoded_bboxes (Tensor): Decoded bboxes predictions of one image,
                a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y,
                br_x, br_y] format.
            gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
                with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (Tensor): Ground truth labels of one image, a Tensor
                with shape [num_gts].
        """
        num_priors = priors.size(0)
        num_gts = gt_labels.size(0)
        gt_bboxes = gt_bboxes.to(decoded_bboxes.dtype)
        # No target
        if num_gts == 0:
            cls_target = cls_preds.new_zeros((0, self.num_classes))
            bbox_target = cls_preds.new_zeros((0, 4))
            l1_target = cls_preds.new_zeros((0, 4))
            obj_target = cls_preds.new_zeros((num_priors, 1))
            foreground_mask = cls_preds.new_zeros(num_priors).bool()
            return (foreground_mask, cls_target, obj_target, bbox_target,
                    l1_target, 0)

        # YOLOX uses center priors with 0.5 offset to assign targets,
        # but uses center priors without offset to regress bboxes.
        offset_priors = torch.cat(
            [priors[:, :2] + priors[:, 2:] * 0.5, priors[:, 2:]], dim=-1)

        assign_result = self.assigner.assign(
            cls_preds.sigmoid() * objectness.unsqueeze(1).sigmoid(),
            offset_priors, decoded_bboxes, gt_bboxes, gt_labels)

        sampling_result = self.sampler.sample(assign_result, priors, gt_bboxes)
        pos_inds = sampling_result.pos_inds
        num_pos_per_img = pos_inds.size(0)

        pos_ious = assign_result.max_overlaps[pos_inds]
        # IOU aware classification score
        cls_target = F.one_hot(sampling_result.pos_gt_labels,
                               self.num_classes) * pos_ious.unsqueeze(-1)
        obj_target = torch.zeros_like(objectness).unsqueeze(-1)
        obj_target[pos_inds] = 1
        bbox_target = sampling_result.pos_gt_bboxes
        l1_target = cls_preds.new_zeros((num_pos_per_img, 4))
        if self.use_l1:
            l1_target = self._get_l1_target(l1_target, bbox_target,
                                            priors[pos_inds])
        foreground_mask = torch.zeros_like(objectness).to(torch.bool)
        foreground_mask[pos_inds] = 1
        return (foreground_mask, cls_target, obj_target, bbox_target,
                l1_target, num_pos_per_img)
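
    # Label assignment above is delegated to `self.assigner`, built from
    # `train_cfg.assigner`; in the reference YOLOX configs this is the SimOTA
    # assigner, which matches priors to ground-truth boxes using the joint
    # cls*objectness scores and the decoded boxes passed in. The resulting
    # classification target is "IoU-aware": a one-hot label scaled by the IoU
    # between each positive prior's decoded box and its assigned ground truth.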

    def _get_l1_target(self, l1_target, gt_bboxes, priors, eps=1e-8):
        """Convert gt bboxes to center offset and log width height."""
        gt_cxcywh = bbox_xyxy_to_cxcywh(gt_bboxes)
        l1_target[:, :2] = (gt_cxcywh[:, :2] - priors[:, :2]) / priors[:, 2:]
        l1_target[:, 2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps)
        return l1_target
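

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module), assuming
    # mmdet and mmcv are installed. Because of the relative imports above, run
    # it as a module, e.g. `python -m mmdet.models.dense_heads.yolox_head`,
    # rather than as a standalone script. The hyper-parameters below are
    # illustrative, not the values used by any particular config.
    head = YOLOXHead(num_classes=80, in_channels=256)
    head.init_weights()
    # Three feature levels at strides (8, 16, 32) for a hypothetical 640x640
    # input, batch size 2, 256 channels each.
    feats = tuple(torch.randn(2, 256, size, size) for size in (80, 40, 20))
    cls_scores, bbox_preds, objectnesses = head(feats)
    for cls_score, bbox_pred, objectness in zip(cls_scores, bbox_preds,
                                                objectnesses):
        print(cls_score.shape, bbox_pred.shape, objectness.shape)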
