
yolof_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, bias_init_with_prob, constant_init, is_norm,
                      normal_init)
from mmcv.runner import force_fp32

from mmdet.core import anchor_inside_flags, multi_apply, reduce_mean, unmap
from ..builder import HEADS
from .anchor_head import AnchorHead

INF = 1e8


def levels_to_images(mlvl_tensor):
    """Concat multi-level feature maps by image.

    [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
    Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
    (N, H*W, C), then split the element to N elements with shape (H*W, C),
    and concat the elements of the same image from all levels along the
    first dimension.

    Args:
        mlvl_tensor (list[torch.Tensor]): list of Tensor which collect from
            corresponding level. Each element is of shape (N, C, H, W).

    Returns:
        list[torch.Tensor]: A list that contains N tensors and each tensor is
            of shape (num_elements, C).
    """
    batch_size = mlvl_tensor[0].size(0)
    batch_list = [[] for _ in range(batch_size)]
    channels = mlvl_tensor[0].size(1)
    for t in mlvl_tensor:
        t = t.permute(0, 2, 3, 1)
        t = t.view(batch_size, -1, channels).contiguous()
        for img in range(batch_size):
            batch_list[img].append(t[img])
    return [torch.cat(item, 0) for item in batch_list]
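

# --- Illustrative sketch (not part of the upstream file) --------------------
# A minimal shape check for `levels_to_images`, assuming two levels with the
# same batch size (N=2) and channel count (C=4): each image's rows from all
# levels are stacked into one (H0*W0 + H1*W1, C) tensor.
def _demo_levels_to_images():
    feats = [torch.rand(2, 4, 8, 8), torch.rand(2, 4, 4, 4)]
    per_img = levels_to_images(feats)
    assert len(per_img) == 2  # one tensor per image
    assert per_img[0].shape == (8 * 8 + 4 * 4, 4)  # (80, 4)
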
@HEADS.register_module()
class YOLOFHead(AnchorHead):
    """YOLOFHead Paper link: https://arxiv.org/abs/2103.09460.

    Args:
        num_classes (int): The number of object classes (w/o background).
        in_channels (List[int]): The number of input channels per scale.
        num_cls_convs (int): The number of convolutions of the cls branch.
            Default 2.
        num_reg_convs (int): The number of convolutions of the reg branch.
            Default 4.
        norm_cfg (dict): Dictionary to construct and config norm layer.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_cls_convs=2,
                 num_reg_convs=4,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 **kwargs):
        self.num_cls_convs = num_cls_convs
        self.num_reg_convs = num_reg_convs
        self.norm_cfg = norm_cfg
        super(YOLOFHead, self).__init__(num_classes, in_channels, **kwargs)

    def _init_layers(self):
        cls_subnet = []
        bbox_subnet = []
        for i in range(self.num_cls_convs):
            cls_subnet.append(
                ConvModule(
                    self.in_channels,
                    self.in_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg))
        for i in range(self.num_reg_convs):
            bbox_subnet.append(
                ConvModule(
                    self.in_channels,
                    self.in_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg))
        self.cls_subnet = nn.Sequential(*cls_subnet)
        self.bbox_subnet = nn.Sequential(*bbox_subnet)
        self.cls_score = nn.Conv2d(
            self.in_channels,
            self.num_base_priors * self.num_classes,
            kernel_size=3,
            stride=1,
            padding=1)
        self.bbox_pred = nn.Conv2d(
            self.in_channels,
            self.num_base_priors * 4,
            kernel_size=3,
            stride=1,
            padding=1)
        self.object_pred = nn.Conv2d(
            self.in_channels,
            self.num_base_priors,
            kernel_size=3,
            stride=1,
            padding=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)

        # Use prior in model initialization to improve stability
        bias_cls = bias_init_with_prob(0.01)
        torch.nn.init.constant_(self.cls_score.bias, bias_cls)

    def forward_single(self, feature):
        cls_score = self.cls_score(self.cls_subnet(feature))
        N, _, H, W = cls_score.shape
        cls_score = cls_score.view(N, -1, self.num_classes, H, W)

        reg_feat = self.bbox_subnet(feature)
        bbox_reg = self.bbox_pred(reg_feat)
        objectness = self.object_pred(reg_feat)

        # implicit objectness
        objectness = objectness.view(N, -1, 1, H, W)
        normalized_cls_score = cls_score + objectness - torch.log(
            1. + torch.clamp(cls_score.exp(), max=INF) +
            torch.clamp(objectness.exp(), max=INF))
        normalized_cls_score = normalized_cls_score.view(N, -1, H, W)
        return normalized_cls_score, bbox_reg
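
    # --- Illustrative note (not part of the upstream file) ------------------
    # The normalization above folds the shared objectness logit into each
    # class logit so that, ignoring the overflow clamp, the result is exactly
    # the joint probability expressed in logit space:
    #
    #   sigmoid(c + o - log(1 + e^c + e^o)) == sigmoid(c) * sigmoid(o)
    #
    # A quick check with scalar tensors:
    #
    #   >>> c, o = torch.randn(()), torch.randn(())
    #   >>> lhs = torch.sigmoid(c + o - torch.log(1. + c.exp() + o.exp()))
    #   >>> torch.allclose(lhs, torch.sigmoid(c) * torch.sigmoid(o),
    #   ...                atol=1e-6)
    #   True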
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                Has shape (batch, num_anchors * num_classes, h, w).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (batch, num_anchors * 4, h, w).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss. Default: None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == 1
        assert self.prior_generator.num_levels == 1

        device = cls_scores[0].device
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)

        # The output level is always 1
        anchor_list = [anchors[0] for anchors in anchor_list]
        valid_flag_list = [valid_flags[0] for valid_flags in valid_flag_list]

        cls_scores_list = levels_to_images(cls_scores)
        bbox_preds_list = levels_to_images(bbox_preds)

        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            cls_scores_list,
            bbox_preds_list,
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None
        (batch_labels, batch_label_weights, num_total_pos, num_total_neg,
         batch_bbox_weights, batch_pos_predicted_boxes,
         batch_target_boxes) = cls_reg_targets

        flatten_labels = batch_labels.reshape(-1)
        batch_label_weights = batch_label_weights.reshape(-1)
        cls_score = cls_scores[0].permute(0, 2, 3,
                                          1).reshape(-1, self.cls_out_channels)

        num_total_samples = (num_total_pos +
                             num_total_neg) if self.sampling else num_total_pos
        num_total_samples = reduce_mean(
            cls_score.new_tensor(num_total_samples)).clamp_(1.0).item()

        # classification loss
        loss_cls = self.loss_cls(
            cls_score,
            flatten_labels,
            batch_label_weights,
            avg_factor=num_total_samples)

        # regression loss
        if batch_pos_predicted_boxes.shape[0] == 0:
            # no pos sample
            loss_bbox = batch_pos_predicted_boxes.sum() * 0
        else:
            loss_bbox = self.loss_bbox(
                batch_pos_predicted_boxes,
                batch_target_boxes,
                batch_bbox_weights.float(),
                avg_factor=num_total_samples)

        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
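
    # --- Illustrative note (not part of the upstream file) ------------------
    # YOLOF predicts on a single feature level (the dilated-encoder output),
    # hence the `len(cls_scores) == 1` assert above. For example, with batch
    # size 2, 5 anchors per position and 80 classes on a 25x38 feature map
    # (shapes here are hypothetical), `loss` would receive:
    #
    #   cls_scores = [tensor of shape (2, 5 * 80, 25, 38)]
    #   bbox_preds = [tensor of shape (2, 5 * 4, 25, 38)]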
    def get_targets(self,
                    cls_scores_list,
                    bbox_preds_list,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True):
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            cls_scores_list (list[Tensor]): Classification scores of
                each image. Each is a 2D tensor of shape
                (h * w, num_anchors * num_classes).
            bbox_preds_list (list[Tensor]): Bbox preds of each image.
                Each is a 2D tensor of shape (h * w, num_anchors * 4).
            anchor_list (list[Tensor]): Anchors of each image. Each element
                is a tensor of shape (h * w * num_anchors, 4).
            valid_flag_list (list[Tensor]): Valid flags of each image. Each
                element is a tensor of shape (h * w * num_anchors, ).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

                - batch_labels (Tensor): Labels of all images, a tensor of
                  shape (batch, h * w * num_anchors).
                - batch_label_weights (Tensor): Label weights of all images,
                  a tensor of shape (batch, h * w * num_anchors).
                - num_total_pos (int): Number of positive samples in all
                  images.
                - num_total_neg (int): Number of negative samples in all
                  images.

            additional_returns: This function enables user-defined returns
                from `self._get_targets_single`. These returns are currently
                refined to properties at each feature map (i.e. having HxW
                dimension). The results will be concatenated at the end.
        """
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        results = multi_apply(
            self._get_targets_single,
            bbox_preds_list,
            anchor_list,
            valid_flag_list,
            gt_bboxes_list,
            gt_bboxes_ignore_list,
            gt_labels_list,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)
        (all_labels, all_label_weights, pos_inds_list, neg_inds_list,
         sampling_results_list) = results[:5]
        rest_results = list(results[5:])  # user-added return values

        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None

        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])

        batch_labels = torch.stack(all_labels, 0)
        batch_label_weights = torch.stack(all_label_weights, 0)

        res = (batch_labels, batch_label_weights, num_total_pos,
               num_total_neg)
        for i, rests in enumerate(rest_results):  # user-added return values
            rest_results[i] = torch.cat(rests, 0)

        return res + tuple(rest_results)
    def _get_targets_single(self,
                            bbox_preds,
                            flat_anchors,
                            valid_flags,
                            gt_bboxes,
                            gt_bboxes_ignore,
                            gt_labels,
                            img_meta,
                            label_channels=1,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            bbox_preds (Tensor): Bbox predictions of the image, of shape
                (h * w, 4).
            flat_anchors (Tensor): Anchors of the image, of shape
                (h * w * num_anchors, 4).
            valid_flags (Tensor): Valid flags of the image, of shape
                (h * w * num_anchors, ).
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts, ).
            img_meta (dict): Meta info of the image.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple:
                labels (Tensor): Labels of the image, of shape
                    (h * w * num_anchors, ).
                label_weights (Tensor): Label weights of the image, of shape
                    (h * w * num_anchors, ).
                pos_inds (Tensor): Pos index of the image.
                neg_inds (Tensor): Neg index of the image.
                sampling_result (obj:`SamplingResult`): Sampling result.
                pos_bbox_weights (Tensor): The weights used to calculate
                    the bbox branch loss, of shape (num, ).
                pos_predicted_boxes (Tensor): The predicted boxes used to
                    calculate the bbox branch loss, of shape (num, 4).
                pos_target_boxes (Tensor): The target boxes used to
                    calculate the bbox branch loss, of shape (num, 4).
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg.allowed_border)
        if not inside_flags.any():
            return (None, ) * 8
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]
        bbox_preds = bbox_preds.reshape(-1, 4)
        bbox_preds = bbox_preds[inside_flags, :]

        # decoded bbox
        decoder_bbox_preds = self.bbox_coder.decode(anchors, bbox_preds)
        assign_result = self.assigner.assign(
            decoder_bbox_preds, anchors, gt_bboxes, gt_bboxes_ignore,
            None if self.sampling else gt_labels)

        pos_bbox_weights = assign_result.get_extra_property('pos_idx')
        pos_predicted_boxes = assign_result.get_extra_property(
            'pos_predicted_boxes')
        pos_target_boxes = assign_result.get_extra_property('target_boxes')

        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)
        num_valid_anchors = anchors.shape[0]
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(
            num_valid_anchors, dtype=torch.float)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class since v2.5.0
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)  # fill bg label
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)

        return (labels, label_weights, pos_inds, neg_inds, sampling_result,
                pos_bbox_weights, pos_predicted_boxes, pos_target_boxes)
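
For reference, the snippet below shows how a head like this is typically wired into a detector config. The values are adapted from the standard YOLOF R-50 configuration in mmdetection (single stride-32 level, five anchor scales, GIoU loss on decoded boxes); treat them as an illustrative sketch rather than authoritative settings for any particular experiment.

    # Illustrative config fragment; exact values may differ per setup.
    bbox_head = dict(
        type='YOLOFHead',
        num_classes=80,
        in_channels=512,
        reg_decoded_bbox=True,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[1, 2, 4, 8, 16],
            strides=[32]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1., 1., 1., 1.],
            add_ctr_clamp=True,
            ctr_clamp=32),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.0))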
