yolact_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, ModuleList, force_fp32

from mmdet.core import build_sampler, fast_nms, images_to_levels, multi_apply
from mmdet.core.utils import select_single_mlvl
from ..builder import HEADS, build_loss
from .anchor_head import AnchorHead


@HEADS.register_module()
class YOLACTHead(AnchorHead):
    """YOLACT box head used in https://arxiv.org/abs/1904.02689.

    Note that YOLACT head is a light version of RetinaNet head.
    Four differences are described as follows:

    1. YOLACT box head has three-times fewer anchors.
    2. YOLACT box head shares the convs for box and cls branches.
    3. YOLACT box head uses OHEM instead of Focal loss.
    4. YOLACT box head predicts a set of mask coefficients for each box.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        anchor_generator (dict): Config dict for anchor generator.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        num_head_convs (int): Number of the conv layers shared by
            box and cls branches.
        num_protos (int): Number of the mask coefficients.
        use_ohem (bool): If true, ``loss_single_OHEM`` will be used for
            cls loss calculation. If false, ``loss_single`` will be used.
        conv_cfg (dict): Dictionary to construct and config conv layer.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     octave_base_scale=3,
                     scales_per_octave=1,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     reduction='none',
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.5),
                 num_head_convs=1,
                 num_protos=32,
                 use_ohem=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 init_cfg=dict(
                     type='Xavier',
                     distribution='uniform',
                     bias=0,
                     layer='Conv2d'),
                 **kwargs):
        self.num_head_convs = num_head_convs
        self.num_protos = num_protos
        self.use_ohem = use_ohem
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        super(YOLACTHead, self).__init__(
            num_classes,
            in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)
        if self.use_ohem:
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
            self.sampling = False

    def _init_layers(self):
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.head_convs = ModuleList()
        for i in range(self.num_head_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.head_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        self.conv_cls = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * self.cls_out_channels,
            3,
            padding=1)
        self.conv_reg = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 4, 3, padding=1)
        self.conv_coeff = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * self.num_protos,
            3,
            padding=1)

    def forward_single(self, x):
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level, \
                    the channel number is num_anchors * num_classes.
                bbox_pred (Tensor): Box energies / deltas for a single scale \
                    level, the channel number is num_anchors * 4.
                coeff_pred (Tensor): Mask coefficients for a single scale \
                    level, the channel number is num_anchors * num_protos.
        """
        for head_conv in self.head_convs:
            x = head_conv(x)
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
        coeff_pred = self.conv_coeff(x).tanh()
        return cls_score, bbox_pred, coeff_pred
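
    # Illustrative output shapes (values assumed for a 550x550 input and the
    # default config, not taken from a real run): with 3 base anchors per
    # location and num_classes=80 (so cls_out_channels=81 with softmax), the
    # stride-8 level of roughly 69x69 yields
    #   cls_score:  (N, 3 * 81, 69, 69)
    #   bbox_pred:  (N, 3 * 4, 69, 69)
    #   coeff_pred: (N, 3 * 32, 69, 69)  # tanh-bounded mask coefficients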

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """A combination of the func:``AnchorHead.loss`` and
        func:``SSDHead.loss``.

        When ``self.use_ohem == True``, it functions like ``SSDHead.loss``,
        otherwise, it follows ``AnchorHead.loss``. Besides, it additionally
        returns ``sampling_results``.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss. Default: None

        Returns:
            tuple:
                dict[str, Tensor]: A dictionary of loss components.
                List[:obj:``SamplingResult``]: Sampler results for each image.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            unmap_outputs=not self.use_ohem,
            return_sampling_results=True)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg, sampling_results) = cls_reg_targets

        if self.use_ohem:
            num_images = len(img_metas)
            all_cls_scores = torch.cat([
                s.permute(0, 2, 3, 1).reshape(
                    num_images, -1, self.cls_out_channels) for s in cls_scores
            ], 1)
            all_labels = torch.cat(labels_list, -1).view(num_images, -1)
            all_label_weights = torch.cat(label_weights_list,
                                          -1).view(num_images, -1)
            all_bbox_preds = torch.cat([
                b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
                for b in bbox_preds
            ], -2)
            all_bbox_targets = torch.cat(bbox_targets_list,
                                         -2).view(num_images, -1, 4)
            all_bbox_weights = torch.cat(bbox_weights_list,
                                         -2).view(num_images, -1, 4)

            # concat all level anchors to a single tensor
            all_anchors = []
            for i in range(num_images):
                all_anchors.append(torch.cat(anchor_list[i]))

            # check NaN and Inf
            assert torch.isfinite(all_cls_scores).all().item(), \
                'classification scores become infinite or NaN!'
            assert torch.isfinite(all_bbox_preds).all().item(), \
                'bbox predictions become infinite or NaN!'

            losses_cls, losses_bbox = multi_apply(
                self.loss_single_OHEM,
                all_cls_scores,
                all_bbox_preds,
                all_anchors,
                all_labels,
                all_label_weights,
                all_bbox_targets,
                all_bbox_weights,
                num_total_samples=num_total_pos)
        else:
            num_total_samples = (
                num_total_pos +
                num_total_neg if self.sampling else num_total_pos)

            # anchor number of multi levels
            num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
            # concat all level anchors and flags to a single tensor
            concat_anchor_list = []
            for i in range(len(anchor_list)):
                concat_anchor_list.append(torch.cat(anchor_list[i]))
            all_anchor_list = images_to_levels(concat_anchor_list,
                                               num_level_anchors)
            losses_cls, losses_bbox = multi_apply(
                self.loss_single,
                cls_scores,
                bbox_preds,
                all_anchor_list,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                bbox_weights_list,
                num_total_samples=num_total_samples)

        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox), sampling_results

    def loss_single_OHEM(self, cls_score, bbox_pred, anchors, labels,
                         label_weights, bbox_targets, bbox_weights,
                         num_total_samples):
        """See func:``SSDHead.loss``."""
        loss_cls_all = self.loss_cls(cls_score, labels, label_weights)

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(
            as_tuple=False).reshape(-1)
        neg_inds = (labels == self.num_classes).nonzero(
            as_tuple=False).view(-1)

        num_pos_samples = pos_inds.size(0)
        if num_pos_samples == 0:
            num_neg_samples = neg_inds.size(0)
        else:
            num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
            if num_neg_samples > neg_inds.size(0):
                num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

        if self.reg_decoded_bbox:
            # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
            # is applied directly on the decoded bounding boxes, it
            # decodes the already encoded coordinates to absolute format.
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_bbox = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_bbox
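
    # Worked example of the OHEM selection above (numbers assumed, not from a
    # real run): with train_cfg.neg_pos_ratio=3 and 10 positive anchors, at
    # most 30 negatives with the highest classification loss are kept; the
    # summed positive + hard-negative loss is then divided by
    # num_total_samples, which the caller sets to the number of positives.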

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'coeff_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   coeff_preds,
                   img_metas,
                   cfg=None,
                   rescale=False):
        """Similar to func:``AnchorHead.get_bboxes``, but additionally
        processes coeff_preds.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            coeff_preds (list[Tensor]): Mask coefficients for each scale
                level with shape (N, num_anchors * num_protos, H, W)
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cfg (mmcv.Config | None): Test / postprocessing configuration,
                if None, test_cfg would be used
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            list[tuple[Tensor, Tensor, Tensor]]: Each item in result_list is
                a 3-tuple. The first item is an (n, 5) tensor, where the
                first 4 columns are bounding box positions
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                between 0 and 1. The second item is an (n,) tensor where each
                item is the predicted class label of the corresponding box.
                The third item is an (n, num_protos) tensor where each item
                is the predicted mask coefficients of instance inside the
                corresponding box.
        """
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)

        device = cls_scores[0].device
        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_anchors = self.prior_generator.grid_priors(
            featmap_sizes, device=device)

        det_bboxes = []
        det_labels = []
        det_coeffs = []
        for img_id in range(len(img_metas)):
            cls_score_list = select_single_mlvl(cls_scores, img_id)
            bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
            coeff_pred_list = select_single_mlvl(coeff_preds, img_id)
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            bbox_res = self._get_bboxes_single(cls_score_list, bbox_pred_list,
                                               coeff_pred_list, mlvl_anchors,
                                               img_shape, scale_factor, cfg,
                                               rescale)
            det_bboxes.append(bbox_res[0])
            det_labels.append(bbox_res[1])
            det_coeffs.append(bbox_res[2])
        return det_bboxes, det_labels, det_coeffs

    def _get_bboxes_single(self,
                           cls_score_list,
                           bbox_pred_list,
                           coeff_preds_list,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Similar to func:``AnchorHead._get_bboxes_single``, but additionally
        processes coeff_preds_list and uses fast NMS instead of traditional
        NMS.

        Args:
            cls_score_list (list[Tensor]): Box scores for a single scale level
                with shape (num_anchors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas for a single
                scale level with shape (num_anchors * 4, H, W).
            coeff_preds_list (list[Tensor]): Mask coefficients for a single
                scale level with shape (num_anchors * num_protos, H, W).
            mlvl_anchors (list[Tensor]): Box reference for a single scale level
                with shape (num_total_anchors, 4).
            img_shape (tuple[int]): Shape of the input image,
                (height, width, 3).
            scale_factor (ndarray): Scale factor of the image arranged as
                (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.

        Returns:
            tuple[Tensor, Tensor, Tensor]: The first item is an (n, 5) tensor,
                where the first 4 columns are bounding box positions
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score between
                0 and 1. The second item is an (n,) tensor where each item is
                the predicted class label of the corresponding box. The third
                item is an (n, num_protos) tensor where each item is the
                predicted mask coefficients of instance inside the
                corresponding box.
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
        nms_pre = cfg.get('nms_pre', -1)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_coeffs = []
        for cls_score, bbox_pred, coeff_pred, anchors in \
                zip(cls_score_list, bbox_pred_list,
                    coeff_preds_list, mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            coeff_pred = coeff_pred.permute(1, 2,
                                            0).reshape(-1, self.num_protos)

            if 0 < nms_pre < scores.shape[0]:
                # Get maximum scores for foreground classes.
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
                    # remind that we set FG labels to [0, num_class-1]
                    # since mmdet v2.0
                    # BG cat_id: num_class
                    max_scores, _ = scores[:, :-1].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                coeff_pred = coeff_pred[topk_inds, :]
            bboxes = self.bbox_coder.decode(
                anchors, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_coeffs.append(coeff_pred)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        mlvl_coeffs = torch.cat(mlvl_coeffs)
        if self.use_sigmoid_cls:
            # Add a dummy background class to the backend when using sigmoid
            # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
            # BG cat_id: num_class
            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
            mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        det_bboxes, det_labels, det_coeffs = fast_nms(mlvl_bboxes, mlvl_scores,
                                                      mlvl_coeffs,
                                                      cfg.score_thr,
                                                      cfg.iou_thr, cfg.top_k,
                                                      cfg.max_per_img)
        return det_bboxes, det_labels, det_coeffs
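

# A minimal test-time sketch for this box head (variable names, config values
# and the 550x550 input size are assumptions for illustration; ``test_cfg`` is
# assumed to carry the fast-NMS settings ``nms_pre``, ``score_thr``,
# ``iou_thr``, ``top_k`` and ``max_per_img``):
#
#     head = YOLACTHead(num_classes=80, in_channels=256,
#                       train_cfg=None, test_cfg=test_cfg)
#     cls_scores, bbox_preds, coeff_preds = head(fpn_feats)
#     det_bboxes, det_labels, det_coeffs = head.get_bboxes(
#         cls_scores, bbox_preds, coeff_preds, img_metas, rescale=True)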


@HEADS.register_module()
class YOLACTSegmHead(BaseModule):
    """YOLACT segmentation head used in https://arxiv.org/abs/1904.02689.

    Apply a semantic segmentation loss on feature space using layers that are
    only evaluated during training to increase performance with no speed
    penalty.

    Args:
        in_channels (int): Number of channels in the input feature map.
        num_classes (int): Number of categories excluding the background
            category.
        loss_segm (dict): Config of semantic segmentation loss.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_classes,
                 in_channels=256,
                 loss_segm=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 init_cfg=dict(
                     type='Xavier',
                     distribution='uniform',
                     override=dict(name='segm_conv'))):
        super(YOLACTSegmHead, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.loss_segm = build_loss(loss_segm)
        self._init_layers()
        self.fp16_enabled = False

    def _init_layers(self):
        """Initialize layers of the head."""
        self.segm_conv = nn.Conv2d(
            self.in_channels, self.num_classes, kernel_size=1)

    def forward(self, x):
        """Forward feature from the upstream network.

        Args:
            x (Tensor): Feature from the upstream network, which is
                a 4D-tensor.

        Returns:
            Tensor: Predicted semantic segmentation map with shape
                (N, num_classes, H, W).
        """
        return self.segm_conv(x)

    @force_fp32(apply_to=('segm_pred', ))
    def loss(self, segm_pred, gt_masks, gt_labels):
        """Compute loss of the head.

        Args:
            segm_pred (list[Tensor]): Predicted semantic segmentation map
                with shape (N, num_classes, H, W).
            gt_masks (list[Tensor]): Ground truth masks for each image with
                the same shape as the input image.
            gt_labels (list[Tensor]): Class indices corresponding to each box.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        loss_segm = []
        num_imgs, num_classes, mask_h, mask_w = segm_pred.size()
        for idx in range(num_imgs):
            cur_segm_pred = segm_pred[idx]
            cur_gt_masks = gt_masks[idx].float()
            cur_gt_labels = gt_labels[idx]
            segm_targets = self.get_targets(cur_segm_pred, cur_gt_masks,
                                            cur_gt_labels)
            if segm_targets is None:
                loss = self.loss_segm(cur_segm_pred,
                                      torch.zeros_like(cur_segm_pred),
                                      torch.zeros_like(cur_segm_pred))
            else:
                loss = self.loss_segm(
                    cur_segm_pred,
                    segm_targets,
                    avg_factor=num_imgs * mask_h * mask_w)
            loss_segm.append(loss)
        return dict(loss_segm=loss_segm)

    def get_targets(self, segm_pred, gt_masks, gt_labels):
        """Compute semantic segmentation targets for each image.

        Args:
            segm_pred (Tensor): Predicted semantic segmentation map
                with shape (num_classes, H, W).
            gt_masks (Tensor): Ground truth masks for each image with
                the same shape as the input image.
            gt_labels (Tensor): Class indices corresponding to each box.

        Returns:
            Tensor: Semantic segmentation targets with shape
                (num_classes, H, W).
        """
        if gt_masks.size(0) == 0:
            return None
        num_classes, mask_h, mask_w = segm_pred.size()
        with torch.no_grad():
            downsampled_masks = F.interpolate(
                gt_masks.unsqueeze(0), (mask_h, mask_w),
                mode='bilinear',
                align_corners=False).squeeze(0)
            downsampled_masks = downsampled_masks.gt(0.5).float()
            segm_targets = torch.zeros_like(segm_pred, requires_grad=False)
            for obj_idx in range(downsampled_masks.size(0)):
                segm_targets[gt_labels[obj_idx] - 1] = torch.max(
                    segm_targets[gt_labels[obj_idx] - 1],
                    downsampled_masks[obj_idx])
            return segm_targets
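
    # Illustrative example (sizes assumed): with segm_pred of shape
    # (num_classes, 69, 69) and two ground-truth instances of the same class,
    # both masks are bilinearly resized to 69x69, binarized at 0.5, and merged
    # into that class channel with an element-wise max, so the channel holds
    # the union of the two instances.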

    def simple_test(self, feats, img_metas, rescale=False):
        """Test function without test-time augmentation."""
        raise NotImplementedError(
            'simple_test of YOLACTSegmHead is not implemented '
            'because this head is only evaluated during training')


@HEADS.register_module()
class YOLACTProtonet(BaseModule):
    """YOLACT mask head used in https://arxiv.org/abs/1904.02689.

    This head outputs the mask prototypes for YOLACT.

    Args:
        in_channels (int): Number of channels in the input feature map.
        proto_channels (tuple[int]): Output channels of protonet convs.
        proto_kernel_sizes (tuple[int]): Kernel sizes of protonet convs.
        include_last_relu (bool): Whether to keep the last ReLU of protonet.
        num_protos (int): Number of prototypes.
        num_classes (int): Number of categories excluding the background
            category.
        loss_mask_weight (float): Reweight the mask loss by this factor.
        max_masks_to_train (int): Maximum number of masks to train for
            each image.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 num_classes,
                 in_channels=256,
                 proto_channels=(256, 256, 256, None, 256, 32),
                 proto_kernel_sizes=(3, 3, 3, -2, 3, 1),
                 include_last_relu=True,
                 num_protos=32,
                 loss_mask_weight=1.0,
                 max_masks_to_train=100,
                 init_cfg=dict(
                     type='Xavier',
                     distribution='uniform',
                     override=dict(name='protonet'))):
        super(YOLACTProtonet, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.proto_channels = proto_channels
        self.proto_kernel_sizes = proto_kernel_sizes
        self.include_last_relu = include_last_relu
        self.protonet = self._init_layers()

        self.loss_mask_weight = loss_mask_weight
        self.num_protos = num_protos
        self.num_classes = num_classes
        self.max_masks_to_train = max_masks_to_train
        self.fp16_enabled = False

    def _init_layers(self):
        """A helper function to take a config setting and turn it into a
        network."""
        # Possible patterns:
        # ( 256, 3) -> conv
        # ( 256,-2) -> deconv
        # (None,-2) -> bilinear interpolate
        in_channels = self.in_channels
        protonets = ModuleList()
        for num_channels, kernel_size in zip(self.proto_channels,
                                             self.proto_kernel_sizes):
            if kernel_size > 0:
                layer = nn.Conv2d(
                    in_channels,
                    num_channels,
                    kernel_size,
                    padding=kernel_size // 2)
            else:
                if num_channels is None:
                    layer = InterpolateModule(
                        scale_factor=-kernel_size,
                        mode='bilinear',
                        align_corners=False)
                else:
                    layer = nn.ConvTranspose2d(
                        in_channels,
                        num_channels,
                        -kernel_size,
                        padding=kernel_size // 2)
            protonets.append(layer)
            protonets.append(nn.ReLU(inplace=True))
            in_channels = num_channels if num_channels is not None \
                else in_channels
        if not self.include_last_relu:
            protonets = protonets[:-1]
        return nn.Sequential(*protonets)
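
    # With the default proto_channels=(256, 256, 256, None, 256, 32) and
    # proto_kernel_sizes=(3, 3, 3, -2, 3, 1), the loop above builds, each
    # layer followed by a ReLU:
    #   Conv2d(256, 256, 3, padding=1)  x 3
    #   InterpolateModule(scale_factor=2, mode='bilinear')  # the (None, -2)
    #   Conv2d(256, 256, 3, padding=1)
    #   Conv2d(256, 32, 1)              # 32 prototype channels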

    def forward_dummy(self, x):
        prototypes = self.protonet(x)
        return prototypes

    def forward(self, x, coeff_pred, bboxes, img_meta, sampling_results=None):
        """Forward feature from the upstream network to get prototypes and
        linearly combine the prototypes, using mask coefficients, into
        instance masks. Finally, crop the instance masks with given bboxes.

        Args:
            x (Tensor): Feature from the upstream network, which is
                a 4D-tensor.
            coeff_pred (list[Tensor]): Mask coefficients for each scale
                level with shape (N, num_anchors * num_protos, H, W).
            bboxes (list[Tensor]): Box used for cropping with shape
                (N, num_anchors * 4, H, W). During training, they are
                ground truth boxes. During testing, they are predicted
                boxes.
            img_meta (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            sampling_results (List[:obj:``SamplingResult``]): Sampler results
                for each image.

        Returns:
            list[Tensor]: Predicted instance segmentation masks.
        """
        prototypes = self.protonet(x)
        prototypes = prototypes.permute(0, 2, 3, 1).contiguous()

        num_imgs = x.size(0)

        # The reason for not using self.training is that
        # val workflow will have a dimension mismatch error.
        # Note that this writing method is very tricky.
        # Fix https://github.com/open-mmlab/mmdetection/issues/5978
        is_train_or_val_workflow = (coeff_pred[0].dim() == 4)

        # Train or val workflow
        if is_train_or_val_workflow:
            coeff_pred_list = []
            for coeff_pred_per_level in coeff_pred:
                coeff_pred_per_level = \
                    coeff_pred_per_level.permute(
                        0, 2, 3, 1).reshape(num_imgs, -1, self.num_protos)
                coeff_pred_list.append(coeff_pred_per_level)
            coeff_pred = torch.cat(coeff_pred_list, dim=1)

        mask_pred_list = []
        for idx in range(num_imgs):
            cur_prototypes = prototypes[idx]
            cur_coeff_pred = coeff_pred[idx]
            cur_bboxes = bboxes[idx]
            cur_img_meta = img_meta[idx]

            # Testing state
            if not is_train_or_val_workflow:
                bboxes_for_cropping = cur_bboxes
            else:
                cur_sampling_results = sampling_results[idx]
                pos_assigned_gt_inds = \
                    cur_sampling_results.pos_assigned_gt_inds
                bboxes_for_cropping = cur_bboxes[pos_assigned_gt_inds].clone()
                pos_inds = cur_sampling_results.pos_inds
                cur_coeff_pred = cur_coeff_pred[pos_inds]

            # Linearly combine the prototypes with the mask coefficients
            mask_pred = cur_prototypes @ cur_coeff_pred.t()
            mask_pred = torch.sigmoid(mask_pred)

            h, w = cur_img_meta['img_shape'][:2]
            bboxes_for_cropping[:, 0] /= w
            bboxes_for_cropping[:, 1] /= h
            bboxes_for_cropping[:, 2] /= w
            bboxes_for_cropping[:, 3] /= h

            mask_pred = self.crop(mask_pred, bboxes_for_cropping)
            mask_pred = mask_pred.permute(2, 0, 1).contiguous()
            mask_pred_list.append(mask_pred)
        return mask_pred_list
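
    # Shape walk-through of the linear combination above (sizes assumed for a
    # 550x550 input): cur_prototypes is (138, 138, 32) after the permute and
    # cur_coeff_pred is (n, 32) for n kept boxes, so
    # cur_prototypes @ cur_coeff_pred.t() gives (138, 138, n); after sigmoid,
    # crop and the final permute, the per-image prediction is (n, 138, 138).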

    @force_fp32(apply_to=('mask_pred', ))
    def loss(self, mask_pred, gt_masks, gt_bboxes, img_meta, sampling_results):
        """Compute loss of the head.

        Args:
            mask_pred (list[Tensor]): Predicted prototypes with shape
                (num_classes, H, W).
            gt_masks (list[Tensor]): Ground truth masks for each image with
                the same shape as the input image.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            img_meta (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            sampling_results (List[:obj:``SamplingResult``]): Sampler results
                for each image.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        loss_mask = []
        num_imgs = len(mask_pred)
        total_pos = 0
        for idx in range(num_imgs):
            cur_mask_pred = mask_pred[idx]
            cur_gt_masks = gt_masks[idx].float()
            cur_gt_bboxes = gt_bboxes[idx]
            cur_img_meta = img_meta[idx]
            cur_sampling_results = sampling_results[idx]

            pos_assigned_gt_inds = cur_sampling_results.pos_assigned_gt_inds
            num_pos = pos_assigned_gt_inds.size(0)
            # Since we're producing (near) full image masks,
            # it'd take too much vram to backprop on every single mask.
            # Thus we select only a subset.
            if num_pos > self.max_masks_to_train:
                perm = torch.randperm(num_pos)
                select = perm[:self.max_masks_to_train]
                cur_mask_pred = cur_mask_pred[select]
                pos_assigned_gt_inds = pos_assigned_gt_inds[select]
                num_pos = self.max_masks_to_train
            total_pos += num_pos

            gt_bboxes_for_reweight = cur_gt_bboxes[pos_assigned_gt_inds]

            mask_targets = self.get_targets(cur_mask_pred, cur_gt_masks,
                                            pos_assigned_gt_inds)
            if num_pos == 0:
                loss = cur_mask_pred.sum() * 0.
            elif mask_targets is None:
                loss = F.binary_cross_entropy(cur_mask_pred,
                                              torch.zeros_like(cur_mask_pred),
                                              torch.zeros_like(cur_mask_pred))
            else:
                cur_mask_pred = torch.clamp(cur_mask_pred, 0, 1)
                loss = F.binary_cross_entropy(
                    cur_mask_pred, mask_targets,
                    reduction='none') * self.loss_mask_weight

                h, w = cur_img_meta['img_shape'][:2]
                gt_bboxes_width = (gt_bboxes_for_reweight[:, 2] -
                                   gt_bboxes_for_reweight[:, 0]) / w
                gt_bboxes_height = (gt_bboxes_for_reweight[:, 3] -
                                    gt_bboxes_for_reweight[:, 1]) / h
                loss = loss.mean(dim=(1,
                                      2)) / gt_bboxes_width / gt_bboxes_height
                loss = torch.sum(loss)
            loss_mask.append(loss)

        if total_pos == 0:
            total_pos += 1  # avoid nan
        loss_mask = [x / total_pos for x in loss_mask]

        return dict(loss_mask=loss_mask)

    def get_targets(self, mask_pred, gt_masks, pos_assigned_gt_inds):
        """Compute instance segmentation targets for each image.

        Args:
            mask_pred (Tensor): Predicted prototypes with shape
                (num_classes, H, W).
            gt_masks (Tensor): Ground truth masks for each image with
                the same shape as the input image.
            pos_assigned_gt_inds (Tensor): GT indices of the corresponding
                positive samples.

        Returns:
            Tensor: Instance segmentation targets with shape
                (num_instances, H, W).
        """
        if gt_masks.size(0) == 0:
            return None
        mask_h, mask_w = mask_pred.shape[-2:]
        gt_masks = F.interpolate(
            gt_masks.unsqueeze(0), (mask_h, mask_w),
            mode='bilinear',
            align_corners=False).squeeze(0)
        gt_masks = gt_masks.gt(0.5).float()
        mask_targets = gt_masks[pos_assigned_gt_inds]
        return mask_targets

    def get_seg_masks(self, mask_pred, label_pred, img_meta, rescale):
        """Resize, binarize, and format the instance mask predictions.

        Args:
            mask_pred (Tensor): shape (N, H, W).
            label_pred (Tensor): shape (N, ).
            img_meta (dict): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If rescale is False, then returned masks will
                fit the scale of imgs[0].

        Returns:
            list[ndarray]: Mask predictions grouped by their predicted classes.
        """
        ori_shape = img_meta['ori_shape']
        scale_factor = img_meta['scale_factor']
        if rescale:
            img_h, img_w = ori_shape[:2]
        else:
            img_h = np.round(ori_shape[0] * scale_factor[1]).astype(np.int32)
            img_w = np.round(ori_shape[1] * scale_factor[0]).astype(np.int32)

        cls_segms = [[] for _ in range(self.num_classes)]
        if mask_pred.size(0) == 0:
            return cls_segms

        mask_pred = F.interpolate(
            mask_pred.unsqueeze(0), (img_h, img_w),
            mode='bilinear',
            align_corners=False).squeeze(0) > 0.5
        mask_pred = mask_pred.cpu().numpy().astype(np.uint8)

        for m, l in zip(mask_pred, label_pred):
            cls_segms[l].append(m)
        return cls_segms

    def crop(self, masks, boxes, padding=1):
        """Crop predicted masks by zeroing out everything not in the predicted
        bbox.

        Args:
            masks (Tensor): shape [H, W, N].
            boxes (Tensor): bbox coords in relative point form with
                shape [N, 4].

        Return:
            Tensor: The cropped masks.
        """
        h, w, n = masks.size()
        x1, x2 = self.sanitize_coordinates(
            boxes[:, 0], boxes[:, 2], w, padding, cast=False)
        y1, y2 = self.sanitize_coordinates(
            boxes[:, 1], boxes[:, 3], h, padding, cast=False)

        rows = torch.arange(
            w, device=masks.device, dtype=x1.dtype).view(1, -1,
                                                         1).expand(h, w, n)
        cols = torch.arange(
            h, device=masks.device, dtype=x1.dtype).view(-1, 1,
                                                         1).expand(h, w, n)

        masks_left = rows >= x1.view(1, 1, -1)
        masks_right = rows < x2.view(1, 1, -1)
        masks_up = cols >= y1.view(1, 1, -1)
        masks_down = cols < y2.view(1, 1, -1)

        crop_mask = masks_left * masks_right * masks_up * masks_down
        return masks * crop_mask.float()

    def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True):
        """Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0,
        and x2 <= image_size. Also converts from relative to absolute
        coordinates and casts the results to long tensors.

        Warning: this does things in-place behind the scenes so
        copy if necessary.

        Args:
            x1 (Tensor): shape (N, ).
            x2 (Tensor): shape (N, ).
            img_size (int): Size of the input image.
            padding (int): x1 >= padding, x2 <= image_size-padding.
            cast (bool): If cast is false, the result won't be cast to longs.

        Returns:
            tuple:
                x1 (Tensor): Sanitized x1.
                x2 (Tensor): Sanitized x2.
        """
        x1 = x1 * img_size
        x2 = x2 * img_size
        if cast:
            x1 = x1.long()
            x2 = x2.long()
        x1 = torch.min(x1, x2)
        x2 = torch.max(x1, x2)
        x1 = torch.clamp(x1 - padding, min=0)
        x2 = torch.clamp(x2 + padding, max=img_size)
        return x1, x2
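
    # Worked example (values assumed): for img_size=138, padding=1 and
    # relative coords x1=0.2, x2=0.1, the scaled values 27.6 and 13.8 are
    # swapped by the min/max and then padded, giving roughly (12.8, 28.6);
    # with cast=True they would first be truncated to longs, giving (12, 28).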

    def simple_test(self,
                    feats,
                    det_bboxes,
                    det_labels,
                    det_coeffs,
                    img_metas,
                    rescale=False):
        """Test function without test-time augmentation.

        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            det_bboxes (list[Tensor]): BBox results of each image. Each
                element is an (n, 5) tensor, where the first 4 columns are
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                between 0 and 1.
            det_labels (list[Tensor]): BBox results of each image. Each
                element is an (n, ) tensor where each entry is the class
                label of the corresponding box.
            det_coeffs (list[Tensor]): Mask coefficients of each image. Each
                element is an (n, m) tensor, where m is the length of the
                coefficient vector.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list]: Encoded masks. The c-th item in the outer list
                corresponds to the c-th class. Given the c-th outer list, the
                i-th item in that inner list is the mask for the i-th box with
                class label c.
        """
        num_imgs = len(img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
            segm_results = [[[] for _ in range(self.num_classes)]
                            for _ in range(num_imgs)]
        else:
            # if det_bboxes is rescaled to the original image size, we need to
            # rescale it back to the testing scale to obtain RoIs.
            if rescale and not isinstance(scale_factors[0], float):
                scale_factors = [
                    torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                    for scale_factor in scale_factors
                ]
            _bboxes = [
                det_bboxes[i][:, :4] *
                scale_factors[i] if rescale else det_bboxes[i][:, :4]
                for i in range(len(det_bboxes))
            ]
            mask_preds = self.forward(feats[0], det_coeffs, _bboxes, img_metas)
            # apply mask post-processing to each image individually
            segm_results = []
            for i in range(num_imgs):
                if det_bboxes[i].shape[0] == 0:
                    segm_results.append([[] for _ in range(self.num_classes)])
                else:
                    segm_result = self.get_seg_masks(mask_preds[i],
                                                     det_labels[i],
                                                     img_metas[i], rescale)
                    segm_results.append(segm_result)
        return segm_results


class InterpolateModule(BaseModule):
    """This is a module version of F.interpolate.

    Any arguments you give it just get passed along for the ride.
    """

    def __init__(self, *args, init_cfg=None, **kwargs):
        super().__init__(init_cfg)
        self.args = args
        self.kwargs = kwargs

    def forward(self, x):
        """Forward features from the upstream network."""
        return F.interpolate(x, *self.args, **self.kwargs)
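

# Usage note (illustrative, values assumed): InterpolateModule(scale_factor=2,
# mode='bilinear', align_corners=False)(x) is equivalent to calling
# F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False);
# this is how the protonet upsamples its features before the final convs.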
