You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ssd_head.py 15 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
  7. from mmcv.runner import force_fp32
  8. from mmdet.core import (build_assigner, build_bbox_coder,
  9. build_prior_generator, build_sampler, multi_apply)
  10. from ..builder import HEADS
  11. from ..losses import smooth_l1_loss
  12. from .anchor_head import AnchorHead
  13. # TODO: add loss evaluator for SSD
  @HEADS.register_module()
  class SSDHead(AnchorHead):
      """SSD head used in https://arxiv.org/abs/1512.02325.

      One classification branch and one regression branch are built for
      every entry of ``in_channels`` (see ``_init_layers``).

      Args:
          num_classes (int): Number of categories excluding the background
              category.
          in_channels (Sequence[int]): Number of channels of the input
              feature maps, one entry per feature level.
          stacked_convs (int): Number of conv layers in cls and reg tower.
              Default: 0.
          feat_channels (int): Number of hidden channels when stacked_convs
              > 0. Default: 256.
          use_depthwise (bool): Whether to use DepthwiseSeparableConv.
              Default: False.
          conv_cfg (dict): Dictionary to construct and config conv layer.
              Default: None.
          norm_cfg (dict): Dictionary to construct and config norm layer.
              Default: None.
          act_cfg (dict): Dictionary to construct and config activation
              layer. Default: None.
          anchor_generator (dict): Config dict for anchor generator.
          bbox_coder (dict): Config of bounding box coder.
          reg_decoded_bbox (bool): If true, the regression loss would be
              applied directly on decoded bounding boxes, converting both
              the predicted boxes and regression targets to absolute
              coordinates format. Default False. It should be `True` when
              using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
          train_cfg (dict): Training config of anchor head.
          test_cfg (dict): Testing config of anchor head.
          init_cfg (dict or list[dict], optional): Initialization config
              dict.
      """  # noqa: W605

      def __init__(self,
                   num_classes=80,
                   in_channels=(512, 1024, 512, 256, 256, 256),
                   stacked_convs=0,
                   feat_channels=256,
                   use_depthwise=False,
                   conv_cfg=None,
                   norm_cfg=None,
                   act_cfg=None,
                   anchor_generator=dict(
                       type='SSDAnchorGenerator',
                       scale_major=False,
                       input_size=300,
                       strides=[8, 16, 32, 64, 100, 300],
                       ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                       basesize_ratio_range=(0.1, 0.9)),
                   bbox_coder=dict(
                       type='DeltaXYWHBBoxCoder',
                       clip_border=True,
                       target_means=[.0, .0, .0, .0],
                       target_stds=[1.0, 1.0, 1.0, 1.0],
                   ),
                   reg_decoded_bbox=False,
                   train_cfg=None,
                   test_cfg=None,
                   init_cfg=dict(
                       type='Xavier',
                       layer='Conv2d',
                       distribution='uniform',
                       bias=0)):
          # ``super(AnchorHead, self)`` starts the MRO lookup *after*
          # AnchorHead, so AnchorHead.__init__ is bypassed on purpose; the
          # attributes are assigned explicitly below instead.
          super(AnchorHead, self).__init__(init_cfg)
          self.num_classes = num_classes
          self.in_channels = in_channels
          self.stacked_convs = stacked_convs
          self.feat_channels = feat_channels
          self.use_depthwise = use_depthwise
          self.conv_cfg = conv_cfg
          self.norm_cfg = norm_cfg
          self.act_cfg = act_cfg

          self.cls_out_channels = num_classes + 1  # add background class
          self.prior_generator = build_prior_generator(anchor_generator)

          # Usually the numbers of anchors for each level are the same
          # except SSD detectors. So it is an int in the most dense
          # heads but a list of int in SSDHead
          self.num_base_priors = self.prior_generator.num_base_priors

          # Needs ``num_base_priors`` and ``cls_out_channels`` set above.
          self._init_layers()

          self.bbox_coder = build_bbox_coder(bbox_coder)
          self.reg_decoded_bbox = reg_decoded_bbox
          self.use_sigmoid_cls = False
          self.cls_focal_loss = False
          self.train_cfg = train_cfg
          self.test_cfg = test_cfg
          # set sampling=False for anchor_target
          self.sampling = False
          if self.train_cfg:
              self.assigner = build_assigner(self.train_cfg.assigner)
              # SSD sampling=False so use PseudoSampler
              sampler_cfg = dict(type='PseudoSampler')
              self.sampler = build_sampler(sampler_cfg, context=self)
          self.fp16_enabled = False

      @property
      def num_anchors(self):
          """Deprecated alias of ``num_base_priors``.

          Emits a warning on every access.

          Returns:
              list[int]: Number of base_anchors on each point of each level.
          """
          warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
                        'please use "num_base_priors" instead')
          return self.num_base_priors

      def _init_layers(self):
          """Initialize layers of the head.

          Populates ``self.cls_convs`` and ``self.reg_convs`` with one
          ``nn.Sequential`` per (``in_channels``, ``num_base_priors``) pair.
          Each sequence is: ``stacked_convs`` tower convs, an extra 3x3
          grouped conv when ``use_depthwise`` is set, then the prediction
          conv.
          """
          self.cls_convs = nn.ModuleList()
          self.reg_convs = nn.ModuleList()
          # TODO: Use registry to choose ConvModule type
          conv = DepthwiseSeparableConvModule \
              if self.use_depthwise else ConvModule
          # The prediction conv is 1x1 (no padding) in the depthwise variant
          # and 3x3 (padding 1) otherwise; both keep the spatial size.
          pred_kernel_size = 1 if self.use_depthwise else 3
          pred_padding = 0 if self.use_depthwise else 1

          for channel, num_base_priors in zip(self.in_channels,
                                              self.num_base_priors):
              cls_layers = []
              reg_layers = []
              in_channel = channel
              # build stacked conv tower, not used in default ssd
              # (cls/reg modules are created alternately; keep this order so
              # module construction order stays the same)
              for _ in range(self.stacked_convs):
                  cls_layers.append(
                      conv(
                          in_channel,
                          self.feat_channels,
                          3,
                          padding=1,
                          conv_cfg=self.conv_cfg,
                          norm_cfg=self.norm_cfg,
                          act_cfg=self.act_cfg))
                  reg_layers.append(
                      conv(
                          in_channel,
                          self.feat_channels,
                          3,
                          padding=1,
                          conv_cfg=self.conv_cfg,
                          norm_cfg=self.norm_cfg,
                          act_cfg=self.act_cfg))
                  in_channel = self.feat_channels
              # SSD-Lite head: 3x3 conv with one group per channel in front
              # of the 1x1 prediction conv
              if self.use_depthwise:
                  cls_layers.append(
                      ConvModule(
                          in_channel,
                          in_channel,
                          3,
                          padding=1,
                          groups=in_channel,
                          conv_cfg=self.conv_cfg,
                          norm_cfg=self.norm_cfg,
                          act_cfg=self.act_cfg))
                  reg_layers.append(
                      ConvModule(
                          in_channel,
                          in_channel,
                          3,
                          padding=1,
                          groups=in_channel,
                          conv_cfg=self.conv_cfg,
                          norm_cfg=self.norm_cfg,
                          act_cfg=self.act_cfg))
              # Prediction layers: one score per (prior, class incl.
              # background) and four box values per prior.
              cls_layers.append(
                  nn.Conv2d(
                      in_channel,
                      num_base_priors * self.cls_out_channels,
                      kernel_size=pred_kernel_size,
                      padding=pred_padding))
              reg_layers.append(
                  nn.Conv2d(
                      in_channel,
                      num_base_priors * 4,
                      kernel_size=pred_kernel_size,
                      padding=pred_padding))
              self.cls_convs.append(nn.Sequential(*cls_layers))
              self.reg_convs.append(nn.Sequential(*reg_layers))

      def forward(self, feats):
          """Forward features from the upstream network.

          Args:
              feats (tuple[Tensor]): Features from the upstream network, each
                  is a 4D-tensor.

          Returns:
              tuple:
                  cls_scores (list[Tensor]): Classification scores for all
                      scale levels, each is a 4D-tensor, the channels number
                      is num_base_priors * (num_classes + 1), i.e. the
                      background class is included.
                  bbox_preds (list[Tensor]): Box energies / deltas for all
                      scale levels, each is a 4D-tensor, the channels number
                      is num_base_priors * 4.
          """
          cls_scores = []
          bbox_preds = []
          # zip() stops at the shortest input, so only levels that have both
          # a feature map and a pair of branches are processed.
          for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                              self.cls_convs):
              cls_scores.append(cls_conv(feat))
              bbox_preds.append(reg_conv(feat))
          return cls_scores, bbox_preds

      def loss_single(self, cls_score, bbox_pred, anchor, labels,
                      label_weights, bbox_targets, bbox_weights,
                      num_total_samples):
          """Compute loss of a single image.

          The classification loss uses hard negative mining: all positive
          anchors contribute, plus the ``neg_pos_ratio * num_pos`` background
          anchors with the largest loss (capped at the number of background
          anchors available).

          Args:
              cls_score (Tensor): Box scores of one image with shape
                  (num_total_anchors, num_classes + 1).
              bbox_pred (Tensor): Box energies / deltas of one image with
                  shape (num_total_anchors, 4).
              anchor (Tensor): Box reference of one image with shape
                  (num_total_anchors, 4).
              labels (Tensor): Labels of each anchor with shape
                  (num_total_anchors,).
              label_weights (Tensor): Label weights of each anchor with shape
                  (num_total_anchors,).
              bbox_targets (Tensor): BBox regression targets of each anchor
                  with shape (num_total_anchors, 4).
              bbox_weights (Tensor): BBox regression loss weights of each
                  anchor with shape (num_total_anchors, 4).
              num_total_samples (int): Normalizer applied to both losses.
                  ``loss`` passes the total number of positive anchors.

          Returns:
              tuple[Tensor, Tensor]: The classification loss as a tensor with
              one leading dimension added (``loss_cls[None]``) and the
              regression loss returned by ``smooth_l1_loss``.
          """
          loss_cls_all = F.cross_entropy(
              cls_score, labels, reduction='none') * label_weights

          # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
          pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(
              as_tuple=False).reshape(-1)
          neg_inds = (labels == self.num_classes).nonzero(
              as_tuple=False).view(-1)

          num_pos_samples = pos_inds.size(0)
          # ``topk`` needs an integer k: cast so that a float
          # ``neg_pos_ratio`` in the config also works, and never ask for
          # more negatives than exist.
          num_neg_samples = min(
              int(self.train_cfg.neg_pos_ratio * num_pos_samples),
              neg_inds.size(0))
          topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
          loss_cls_pos = loss_cls_all[pos_inds].sum()
          loss_cls_neg = topk_loss_cls_neg.sum()
          loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

          if self.reg_decoded_bbox:
              # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
              # is applied directly on the decoded bounding boxes, it
              # decodes the already encoded coordinates to absolute format.
              bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)

          loss_bbox = smooth_l1_loss(
              bbox_pred,
              bbox_targets,
              bbox_weights,
              beta=self.train_cfg.smoothl1_beta,
              avg_factor=num_total_samples)
          return loss_cls[None], loss_bbox

      @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
      def loss(self,
               cls_scores,
               bbox_preds,
               gt_bboxes,
               gt_labels,
               img_metas,
               gt_bboxes_ignore=None):
          """Compute losses of the head.

          Args:
              cls_scores (list[Tensor]): Box scores for each scale level
                  Has shape (N, num_anchors * (num_classes + 1), H, W)
              bbox_preds (list[Tensor]): Box energies / deltas for each scale
                  level with shape (N, num_anchors * 4, H, W)
              gt_bboxes (list[Tensor]): each item are the truth boxes for
                  each image in [tl_x, tl_y, br_x, br_y] format.
              gt_labels (list[Tensor]): class indices corresponding to each
                  box
              img_metas (list[dict]): Meta information of each image, e.g.,
                  image size, scaling factor, etc.
              gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                  boxes can be ignored when computing the loss.

          Returns:
              dict | None: A dictionary with the keys ``loss_cls`` and
              ``loss_bbox`` holding the results collected from
              ``loss_single``, or None when ``get_targets`` returns None.
          """
          featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
          assert len(featmap_sizes) == self.prior_generator.num_levels

          device = cls_scores[0].device

          anchor_list, valid_flag_list = self.get_anchors(
              featmap_sizes, img_metas, device=device)
          cls_reg_targets = self.get_targets(
              anchor_list,
              valid_flag_list,
              gt_bboxes,
              img_metas,
              gt_bboxes_ignore_list=gt_bboxes_ignore,
              gt_labels_list=gt_labels,
              label_channels=1,
              unmap_outputs=False)
          if cls_reg_targets is None:
              return None
          (labels_list, label_weights_list, bbox_targets_list,
           bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets

          num_images = len(img_metas)
          # Flatten every level from (N, A * C, H, W) to (N, H * W * A, C)
          # and join the levels along the anchor dimension.
          all_cls_scores = torch.cat([
              s.permute(0, 2, 3, 1).reshape(
                  num_images, -1, self.cls_out_channels) for s in cls_scores
          ], 1)
          all_labels = torch.cat(labels_list, -1).view(num_images, -1)
          all_label_weights = torch.cat(label_weights_list,
                                        -1).view(num_images, -1)
          all_bbox_preds = torch.cat([
              b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
              for b in bbox_preds
          ], -2)
          all_bbox_targets = torch.cat(bbox_targets_list,
                                       -2).view(num_images, -1, 4)
          all_bbox_weights = torch.cat(bbox_weights_list,
                                       -2).view(num_images, -1, 4)

          # concat all level anchors to a single tensor
          all_anchors = [
              torch.cat(anchor_list[i]) for i in range(num_images)
          ]

          # NOTE(review): multi_apply comes from mmdet.core; it is expected
          # to call loss_single once per image with the keyword argument
          # shared across calls - confirm against its implementation.
          losses_cls, losses_bbox = multi_apply(
              self.loss_single,
              all_cls_scores,
              all_bbox_preds,
              all_anchors,
              all_labels,
              all_label_weights,
              all_bbox_targets,
              all_bbox_weights,
              num_total_samples=num_total_pos)
          return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)

No Description

Contributors (3)