
yolo_head.py
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) 2019 Western Digital Corporation or its affiliates.
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (ConvModule, bias_init_with_prob, constant_init, is_norm,
                      normal_init)
from mmcv.runner import force_fp32

from mmdet.core import (build_assigner, build_bbox_coder,
                        build_prior_generator, build_sampler, images_to_levels,
                        multi_apply, multiclass_nms)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin


@HEADS.register_module()
class YOLOV3Head(BaseDenseHead, BBoxTestMixin):
    """YOLOV3Head Paper link: https://arxiv.org/abs/1804.02767.

    Args:
        num_classes (int): The number of object classes (w/o background).
        in_channels (List[int]): Number of input channels per scale.
        out_channels (List[int]): The number of output channels per scale
            before the final 1x1 layer. Default: (1024, 512, 256).
        anchor_generator (dict): Config dict for anchor generator.
        bbox_coder (dict): Config of bounding box coder.
        featmap_strides (List[int]): The stride of each scale.
            Should be in descending order. Default: (32, 16, 8).
        one_hot_smoother (float): Set a non-zero value to enable label-smooth.
            Default: 0.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', negative_slope=0.1).
        loss_cls (dict): Config of classification loss.
        loss_conf (dict): Config of confidence loss.
        loss_xy (dict): Config of xy coordinate loss.
        loss_wh (dict): Config of wh coordinate loss.
        train_cfg (dict): Training config of YOLOV3 head. Default: None.
        test_cfg (dict): Testing config of YOLOV3 head. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 out_channels=(1024, 512, 256),
                 anchor_generator=dict(
                     type='YOLOAnchorGenerator',
                     base_sizes=[[(116, 90), (156, 198), (373, 326)],
                                 [(30, 61), (62, 45), (59, 119)],
                                 [(10, 13), (16, 30), (33, 23)]],
                     strides=[32, 16, 8]),
                 bbox_coder=dict(type='YOLOBBoxCoder'),
                 featmap_strides=[32, 16, 8],
                 one_hot_smoother=0.,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_conf=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_xy=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_wh=dict(type='MSELoss', loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal', std=0.01,
                     override=dict(name='convs_pred'))):
        super(YOLOV3Head, self).__init__(init_cfg)
        # Check params
        assert (len(in_channels) == len(out_channels) == len(featmap_strides))

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.featmap_strides = featmap_strides
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            if hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        self.fp16_enabled = False
        self.one_hot_smoother = one_hot_smoother

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.prior_generator = build_prior_generator(anchor_generator)

        self.loss_cls = build_loss(loss_cls)
        self.loss_conf = build_loss(loss_conf)
        self.loss_xy = build_loss(loss_xy)
        self.loss_wh = build_loss(loss_wh)

        self.num_base_priors = self.prior_generator.num_base_priors[0]
        assert len(
            self.prior_generator.num_base_priors) == len(featmap_strides)
        self._init_layers()
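    # Construction sketch (illustrative, not part of the original file):
    # assuming an 80-class COCO setup fed by a Darknet-53 neck, the head
    # could be built with the defaults above, e.g.
    #
    #   head = YOLOV3Head(num_classes=80, in_channels=[512, 256, 128])
    #
    # where in_channels must match the neck outputs per scale, listed in
    # descending stride order and aligned with out_channels/featmap_strides.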

    @property
    def anchor_generator(self):
        warnings.warn('DeprecationWarning: `anchor_generator` is deprecated, '
                      'please use "prior_generator" instead')
        return self.prior_generator

    @property
    def num_anchors(self):
        """
        Returns:
            int: Number of anchors on each point of feature map.
        """
        warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
                      'please use "num_base_priors" instead')
        return self.num_base_priors

    @property
    def num_levels(self):
        return len(self.featmap_strides)

    @property
    def num_attrib(self):
        """int: number of attributes in pred_map, bboxes (4) +
        objectness (1) + num_classes"""
        return 5 + self.num_classes
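    # For example (illustrative): with num_classes=80, num_attrib is
    # 4 + 1 + 80 = 85, so each prediction conv outputs
    # num_base_priors * num_attrib = 3 * 85 = 255 channels per level.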

    def _init_layers(self):
        self.convs_bridge = nn.ModuleList()
        self.convs_pred = nn.ModuleList()
        for i in range(self.num_levels):
            conv_bridge = ConvModule(
                self.in_channels[i],
                self.out_channels[i],
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            conv_pred = nn.Conv2d(self.out_channels[i],
                                  self.num_base_priors * self.num_attrib, 1)

            self.convs_bridge.append(conv_bridge)
            self.convs_pred.append(conv_pred)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)

        # Use prior in model initialization to improve stability
        for conv_pred, stride in zip(self.convs_pred, self.featmap_strides):
            bias = conv_pred.bias.reshape(self.num_base_priors, -1)
            # init objectness with prior of 8 objects per feature map
            # refer to https://github.com/ultralytics/yolov3
            nn.init.constant_(bias.data[:, 4],
                              bias_init_with_prob(8 / (608 / stride)**2))
            nn.init.constant_(bias.data[:, 5:], bias_init_with_prob(0.01))
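    # The prior above works out as follows (illustrative, assuming a 608x608
    # input): at stride 32 the feature map is 19x19 = 361 cells, so the
    # expected objectness prior is p = 8 / 361 ~= 0.022, and
    # bias_init_with_prob(p) = -log((1 - p) / p) sets the conv bias so that
    # sigmoid(bias) ~= p at the start of training.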

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple[Tensor]: A tuple of multi-level prediction maps, where each
                is a 4D-tensor of shape
                (batch_size, num_base_priors * num_attrib, height, width).
        """
        assert len(feats) == self.num_levels
        pred_maps = []
        for i in range(self.num_levels):
            x = feats[i]
            x = self.convs_bridge[i](x)
            pred_map = self.convs_pred[i](x)
            pred_maps.append(pred_map)

        return tuple(pred_maps),
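    # Shape sketch (illustrative): for an 80-class model with a 608x608 input,
    # forward() returns a 1-tuple holding three maps of shapes
    # (N, 255, 19, 19), (N, 255, 38, 38) and (N, 255, 76, 76) for strides
    # 32, 16 and 8. The trailing comma wraps them in an outer tuple so that
    # loss() and get_bboxes() receive pred_maps as a single argument.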

    @force_fp32(apply_to=('pred_maps', ))
    def get_bboxes(self,
                   pred_maps,
                   img_metas,
                   cfg=None,
                   rescale=False,
                   with_nms=True):
        """Transform network output for a batch into bbox predictions. It has
        been accelerated since PR #5991.

        Args:
            pred_maps (list[Tensor]): Raw predictions for a batch of images.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cfg (mmcv.Config | None): Test / postprocessing configuration,
                if None, test_cfg would be used. Default: None.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
                The first item is an (n, 5) tensor, where 5 represent
                (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
                The shape of the second tensor in the tuple is (n,), and
                each element represents the class label of the corresponding
                box.
        """
        assert len(pred_maps) == self.num_levels
        cfg = self.test_cfg if cfg is None else cfg
        scale_factors = [img_meta['scale_factor'] for img_meta in img_metas]

        num_imgs = len(img_metas)
        featmap_sizes = [pred_map.shape[-2:] for pred_map in pred_maps]

        mlvl_anchors = self.prior_generator.grid_priors(
            featmap_sizes, device=pred_maps[0].device)
        flatten_preds = []
        flatten_strides = []
        for pred, stride in zip(pred_maps, self.featmap_strides):
            pred = pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                    self.num_attrib)
            pred[..., :2].sigmoid_()
            flatten_preds.append(pred)
            flatten_strides.append(
                pred.new_tensor(stride).expand(pred.size(1)))

        flatten_preds = torch.cat(flatten_preds, dim=1)
        flatten_bbox_preds = flatten_preds[..., :4]
        flatten_objectness = flatten_preds[..., 4].sigmoid()
        flatten_cls_scores = flatten_preds[..., 5:].sigmoid()
        flatten_anchors = torch.cat(mlvl_anchors)
        flatten_strides = torch.cat(flatten_strides)
        flatten_bboxes = self.bbox_coder.decode(flatten_anchors,
                                                flatten_bbox_preds,
                                                flatten_strides.unsqueeze(-1))

        if with_nms and (flatten_objectness.size(0) == 0):
            return torch.zeros((0, 5)), torch.zeros((0, ))

        if rescale:
            flatten_bboxes /= flatten_bboxes.new_tensor(
                scale_factors).unsqueeze(1)

        padding = flatten_bboxes.new_zeros(num_imgs, flatten_bboxes.shape[1],
                                           1)
        flatten_cls_scores = torch.cat([flatten_cls_scores, padding], dim=-1)

        det_results = []
        for (bboxes, scores, objectness) in zip(flatten_bboxes,
                                                flatten_cls_scores,
                                                flatten_objectness):
            # Filtering out all predictions with conf < conf_thr
            conf_thr = cfg.get('conf_thr', -1)
            if conf_thr > 0:
                conf_inds = objectness >= conf_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                objectness = objectness[conf_inds]

            det_bboxes, det_labels = multiclass_nms(
                bboxes,
                scores,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=objectness)
            det_results.append(tuple([det_bboxes, det_labels]))
        return det_results
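    # Postprocessing sketch (illustrative): a typical test_cfg for this head
    # could look like
    #
    #   test_cfg = dict(
    #       nms_pre=1000,
    #       min_bbox_size=0,
    #       score_thr=0.05,
    #       conf_thr=0.005,
    #       nms=dict(type='nms', iou_threshold=0.45),
    #       max_per_img=100)
    #
    # where predictions with objectness below conf_thr are dropped before
    # multiclass_nms, and objectness is also passed as a score_factor to
    # rescale the class scores.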

    @force_fp32(apply_to=('pred_maps', ))
    def loss(self,
             pred_maps,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            pred_maps (list[Tensor]): Prediction map for each scale level,
                shape (N, num_anchors * num_attrib, H, W)
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_imgs = len(img_metas)
        device = pred_maps[0][0].device

        featmap_sizes = [
            pred_maps[i].shape[-2:] for i in range(self.num_levels)
        ]
        mlvl_anchors = self.prior_generator.grid_priors(
            featmap_sizes, device=device)
        anchor_list = [mlvl_anchors for _ in range(num_imgs)]

        responsible_flag_list = []
        for img_id in range(len(img_metas)):
            responsible_flag_list.append(
                self.prior_generator.responsible_flags(featmap_sizes,
                                                       gt_bboxes[img_id],
                                                       device))

        target_maps_list, neg_maps_list = self.get_targets(
            anchor_list, responsible_flag_list, gt_bboxes, gt_labels)

        losses_cls, losses_conf, losses_xy, losses_wh = multi_apply(
            self.loss_single, pred_maps, target_maps_list, neg_maps_list)

        return dict(
            loss_cls=losses_cls,
            loss_conf=losses_conf,
            loss_xy=losses_xy,
            loss_wh=losses_wh)
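    # Training sketch (illustrative): train_cfg typically wires in the
    # YOLO-specific grid assigner, e.g.
    #
    #   train_cfg = dict(
    #       assigner=dict(
    #           type='GridAssigner',
    #           pos_iou_thr=0.5,
    #           neg_iou_thr=0.5,
    #           min_pos_iou=0))
    #
    # Each entry of the returned dict is a per-level list: multi_apply runs
    # loss_single once per scale level, so four lists of num_levels tensors
    # come back.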

    def loss_single(self, pred_map, target_map, neg_map):
        """Compute loss of a single scale level over a batch of images.

        Args:
            pred_map (Tensor): Raw predictions for a single level.
            target_map (Tensor): The Ground-Truth target for a single level.
            neg_map (Tensor): The negative masks for a single level.

        Returns:
            tuple:
                loss_cls (Tensor): Classification loss.
                loss_conf (Tensor): Confidence loss.
                loss_xy (Tensor): Regression loss of x, y coordinate.
                loss_wh (Tensor): Regression loss of w, h coordinate.
        """
        num_imgs = len(pred_map)
        pred_map = pred_map.permute(0, 2, 3,
                                    1).reshape(num_imgs, -1, self.num_attrib)
        neg_mask = neg_map.float()
        pos_mask = target_map[..., 4]
        pos_and_neg_mask = neg_mask + pos_mask
        pos_mask = pos_mask.unsqueeze(dim=-1)
        if torch.max(pos_and_neg_mask) > 1.:
            warnings.warn('There is overlap between pos and neg sample.')
            pos_and_neg_mask = pos_and_neg_mask.clamp(min=0., max=1.)

        pred_xy = pred_map[..., :2]
        pred_wh = pred_map[..., 2:4]
        pred_conf = pred_map[..., 4]
        pred_label = pred_map[..., 5:]

        target_xy = target_map[..., :2]
        target_wh = target_map[..., 2:4]
        target_conf = target_map[..., 4]
        target_label = target_map[..., 5:]

        loss_cls = self.loss_cls(pred_label, target_label, weight=pos_mask)
        loss_conf = self.loss_conf(
            pred_conf, target_conf, weight=pos_and_neg_mask)
        loss_xy = self.loss_xy(pred_xy, target_xy, weight=pos_mask)
        loss_wh = self.loss_wh(pred_wh, target_wh, weight=pos_mask)

        return loss_cls, loss_conf, loss_xy, loss_wh
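    # Mask semantics (illustrative reading of the code above): pos_mask comes
    # from the target map's objectness channel and gates the cls/xy/wh losses
    # to responsible anchors only, while the confidence loss also sees the
    # sampled negatives through pos_and_neg_mask, so background anchors are
    # pushed towards zero objectness.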

    def get_targets(self, anchor_list, responsible_flag_list, gt_bboxes_list,
                    gt_labels_list):
        """Compute target maps for anchors in multiple images.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_total_anchors, 4).
            responsible_flag_list (list[list[Tensor]]): Multi level responsible
                flags of each image. Each element is a tensor of shape
                (num_total_anchors, ).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.

        Returns:
            tuple: Usually returns a tuple containing learning targets.
                - target_map_list (list[Tensor]): Target map of each level.
                - neg_map_list (list[Tensor]): Negative map of each level.
        """
        num_imgs = len(anchor_list)

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        results = multi_apply(self._get_targets_single, anchor_list,
                              responsible_flag_list, gt_bboxes_list,
                              gt_labels_list)

        all_target_maps, all_neg_maps = results
        assert num_imgs == len(all_target_maps) == len(all_neg_maps)
        target_maps_list = images_to_levels(all_target_maps, num_level_anchors)
        neg_maps_list = images_to_levels(all_neg_maps, num_level_anchors)

        return target_maps_list, neg_maps_list

    def _get_targets_single(self, anchors, responsible_flags, gt_bboxes,
                            gt_labels):
        """Generate matching bounding box prior and converted GT.

        Args:
            anchors (list[Tensor]): Multi-level anchors of the image.
            responsible_flags (list[Tensor]): Multi-level responsible flags of
                anchors.
            gt_bboxes (Tensor): Ground truth bboxes of single image.
            gt_labels (Tensor): Ground truth labels of single image.

        Returns:
            tuple:
                target_map (Tensor): Prediction target map of each
                    scale level, shape (num_total_anchors,
                    5+num_classes)
                neg_map (Tensor): Negative map of each scale level,
                    shape (num_total_anchors,)
        """
        anchor_strides = []
        for i in range(len(anchors)):
            anchor_strides.append(
                torch.tensor(self.featmap_strides[i],
                             device=gt_bboxes.device).repeat(len(anchors[i])))
        concat_anchors = torch.cat(anchors)
        concat_responsible_flags = torch.cat(responsible_flags)

        anchor_strides = torch.cat(anchor_strides)
        assert len(anchor_strides) == len(concat_anchors) == \
               len(concat_responsible_flags)
        assign_result = self.assigner.assign(concat_anchors,
                                             concat_responsible_flags,
                                             gt_bboxes)
        sampling_result = self.sampler.sample(assign_result, concat_anchors,
                                              gt_bboxes)

        target_map = concat_anchors.new_zeros(
            concat_anchors.size(0), self.num_attrib)

        target_map[sampling_result.pos_inds, :4] = self.bbox_coder.encode(
            sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes,
            anchor_strides[sampling_result.pos_inds])

        target_map[sampling_result.pos_inds, 4] = 1

        gt_labels_one_hot = F.one_hot(
            gt_labels, num_classes=self.num_classes).float()
        if self.one_hot_smoother != 0:  # label smooth
            gt_labels_one_hot = gt_labels_one_hot * (
                1 - self.one_hot_smoother
            ) + self.one_hot_smoother / self.num_classes
        target_map[sampling_result.pos_inds, 5:] = gt_labels_one_hot[
            sampling_result.pos_assigned_gt_inds]

        neg_map = concat_anchors.new_zeros(
            concat_anchors.size(0), dtype=torch.uint8)
        neg_map[sampling_result.neg_inds] = 1

        return target_map, neg_map
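    # Label-smoothing sketch (illustrative): with one_hot_smoother=0.1 and
    # num_classes=80, a positive class target becomes
    # 1 * (1 - 0.1) + 0.1 / 80 = 0.90125 and every other class becomes
    # 0.1 / 80 = 0.00125, softening the one-hot targets fed to loss_cls.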

    def aug_test(self, feats, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class
        """
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)

    @force_fp32(apply_to=('pred_maps', ))
    def onnx_export(self, pred_maps, img_metas, with_nms=True):
        """Export the head's postprocessing to ONNX-friendly ops.

        Args:
            pred_maps (list[Tensor]): Raw predictions for a batch of images.
            img_metas (list[dict]): Meta information of each image.
            with_nms (bool): If True, attach an ONNX NMS op to the graph.
                Default: True.
        """
        num_levels = len(pred_maps)
        pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)]

        cfg = self.test_cfg
        assert len(pred_maps_list) == self.num_levels

        device = pred_maps_list[0].device
        batch_size = pred_maps_list[0].shape[0]

        featmap_sizes = [
            pred_maps_list[i].shape[-2:] for i in range(self.num_levels)
        ]
        mlvl_anchors = self.prior_generator.grid_priors(
            featmap_sizes, device=device)
        # convert to tensor to keep tracing
        nms_pre_tensor = torch.tensor(
            cfg.get('nms_pre', -1), device=device, dtype=torch.long)

        multi_lvl_bboxes = []
        multi_lvl_cls_scores = []
        multi_lvl_conf_scores = []
        for i in range(self.num_levels):
            # get some key info for current scale
            pred_map = pred_maps_list[i]
            stride = self.featmap_strides[i]
            # (b, h, w, num_anchors*num_attrib) ->
            # (b, h*w*num_anchors, num_attrib)
            pred_map = pred_map.permute(0, 2, 3,
                                        1).reshape(batch_size, -1,
                                                   self.num_attrib)
            # Inplace operation like
            # ```pred_map[..., :2] = \
            #        torch.sigmoid(pred_map[..., :2])```
            # would create constant tensor when exporting to onnx
            pred_map_conf = torch.sigmoid(pred_map[..., :2])
            pred_map_rest = pred_map[..., 2:]
            pred_map = torch.cat([pred_map_conf, pred_map_rest], dim=-1)
            pred_map_boxes = pred_map[..., :4]
            multi_lvl_anchor = mlvl_anchors[i]
            multi_lvl_anchor = multi_lvl_anchor.expand_as(pred_map_boxes)
            bbox_pred = self.bbox_coder.decode(multi_lvl_anchor,
                                               pred_map_boxes, stride)
            # conf and cls
            conf_pred = torch.sigmoid(pred_map[..., 4])
            cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
                batch_size, -1, self.num_classes)  # Cls pred one-hot.

            # Get top-k prediction
            from mmdet.core.export import get_k_for_topk
            nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
            if nms_pre > 0:
                _, topk_inds = conf_pred.topk(nms_pre)
                batch_inds = torch.arange(batch_size).view(
                    -1, 1).expand_as(topk_inds).long()
                # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
                transformed_inds = (
                    bbox_pred.shape[1] * batch_inds + topk_inds)
                bbox_pred = bbox_pred.reshape(-1,
                                              4)[transformed_inds, :].reshape(
                                                  batch_size, -1, 4)
                cls_pred = cls_pred.reshape(
                    -1, self.num_classes)[transformed_inds, :].reshape(
                        batch_size, -1, self.num_classes)
                conf_pred = conf_pred.reshape(-1, 1)[transformed_inds].reshape(
                    batch_size, -1)

            # Save the result of current scale
            multi_lvl_bboxes.append(bbox_pred)
            multi_lvl_cls_scores.append(cls_pred)
            multi_lvl_conf_scores.append(conf_pred)

        # Merge the results of different scales together
        batch_mlvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1)
        batch_mlvl_scores = torch.cat(multi_lvl_cls_scores, dim=1)
        batch_mlvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1)

        # Replace multiclass_nms with ONNX::NonMaxSuppression in deployment
        from mmdet.core.export import add_dummy_nms_for_onnx
        conf_thr = cfg.get('conf_thr', -1)
        score_thr = cfg.get('score_thr', -1)
        # follow original pipeline of YOLOv3
        if conf_thr > 0:
            mask = (batch_mlvl_conf_scores >= conf_thr).float()
            batch_mlvl_conf_scores *= mask
        if score_thr > 0:
            mask = (batch_mlvl_scores > score_thr).float()
            batch_mlvl_scores *= mask
        batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2).expand_as(
            batch_mlvl_scores)
        batch_mlvl_scores = batch_mlvl_scores * batch_mlvl_conf_scores
        if with_nms:
            max_output_boxes_per_class = cfg.nms.get(
                'max_output_boxes_per_class', 200)
            iou_threshold = cfg.nms.get('iou_threshold', 0.5)
            # keep aligned with original pipeline, improve
            # mAP by 1% for YOLOv3 in ONNX
            score_threshold = 0
            nms_pre = cfg.get('deploy_nms_pre', -1)
            return add_dummy_nms_for_onnx(
                batch_mlvl_bboxes,
                batch_mlvl_scores,
                max_output_boxes_per_class,
                iou_threshold,
                score_threshold,
                nms_pre,
                cfg.max_per_img,
            )
        else:
            return batch_mlvl_bboxes, batch_mlvl_scores
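
# End-to-end sketch (illustrative, not part of the original file): building
# the head standalone and running a dummy forward pass. Config values mirror
# the defaults above; in practice the head is built from a detector config
# via mmdet's build_head.
#
#   from mmdet.models import build_head
#
#   head_cfg = dict(
#       type='YOLOV3Head',
#       num_classes=80,
#       in_channels=[512, 256, 128],
#       out_channels=[1024, 512, 256])
#   head = build_head(head_cfg)
#   head.init_weights()
#   feats = (torch.rand(1, 512, 19, 19), torch.rand(1, 256, 38, 38),
#            torch.rand(1, 128, 76, 76))
#   pred_maps, = head(feats)  # three maps with 3 * 85 = 255 channels each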
