panoptic_fpn_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmcv.runner import ModuleList

from ..builder import HEADS
from ..utils import ConvUpsample
from .base_semantic_head import BaseSemanticHead

@HEADS.register_module()
class PanopticFPNHead(BaseSemanticHead):
    """PanopticFPNHead used in Panoptic FPN.

    In this head, the number of output channels is ``num_stuff_classes
    + 1``, including all stuff classes and one thing class. The stuff
    classes will be reset from ``0`` to ``num_stuff_classes - 1``, and the
    thing classes will be merged to the ``num_stuff_classes``-th channel.

    Args:
        num_things_classes (int): Number of thing classes. Default: 80.
        num_stuff_classes (int): Number of stuff classes. Default: 53.
        num_classes (int): Number of classes, including all stuff
            classes and one thing class. This argument is deprecated,
            please use ``num_things_classes`` and ``num_stuff_classes``.
            The module will automatically infer ``num_classes`` as
            ``num_stuff_classes + 1``.
        in_channels (int): Number of channels in the input feature map.
        inner_channels (int): Number of channels in inner features.
        start_level (int): The start level of the input features used in
            PanopticFPN.
        end_level (int): The end level of the used features; the
            ``end_level``-th layer will not be used.
        fg_range (tuple): Range of the foreground classes. It starts
            from ``0`` to ``num_things_classes - 1``. Deprecated, please use
            ``num_things_classes`` directly.
        bg_range (tuple): Range of the background classes. It starts
            from ``num_things_classes`` to ``num_things_classes +
            num_stuff_classes - 1``. Deprecated, please use
            ``num_stuff_classes`` and ``num_things_classes`` directly.
        conv_cfg (dict): Dictionary to construct and config the conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config the norm layer.
            Use ``GN`` by default.
        init_cfg (dict or list[dict], optional): Initialization config dict.
        loss_seg (dict): The loss config of the semantic head.
    """

    def __init__(self,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_classes=None,
                 in_channels=256,
                 inner_channels=128,
                 start_level=0,
                 end_level=4,
                 fg_range=None,
                 bg_range=None,
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                 init_cfg=None,
                 loss_seg=dict(
                     type='CrossEntropyLoss', ignore_index=-1,
                     loss_weight=1.0)):
        if num_classes is not None:
            warnings.warn(
                '`num_classes` is deprecated now, please set '
                '`num_stuff_classes` directly, the `num_classes` will be '
                'set to `num_stuff_classes + 1`')
            # num_classes = num_stuff_classes + 1 for PanopticFPN.
            assert num_classes == num_stuff_classes + 1
        super(PanopticFPNHead, self).__init__(num_stuff_classes + 1, init_cfg,
                                              loss_seg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        if fg_range is not None and bg_range is not None:
            self.fg_range = fg_range
            self.bg_range = bg_range
            self.num_things_classes = fg_range[1] - fg_range[0] + 1
            self.num_stuff_classes = bg_range[1] - bg_range[0] + 1
            warnings.warn(
                '`fg_range` and `bg_range` are deprecated now, '
                f'please use `num_things_classes`={self.num_things_classes} '
                f'and `num_stuff_classes`={self.num_stuff_classes} instead.')

        # Used feature layers are [start_level, end_level).
        self.start_level = start_level
        self.end_level = end_level
        self.num_stages = end_level - start_level
        self.inner_channels = inner_channels

        self.conv_upsample_layers = ModuleList()
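        # Note added for clarity (not in the original file): each ConvUpsample
        # entry applies ``num_layers`` convs and ``num_upsample`` 2x
        # upsamplings, so with the default ``start_level=0`` every level is
        # brought back to the spatial size of the first (stride-4) FPN map
        # before the per-level features are summed in ``forward``.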
        for i in range(start_level, end_level):
            self.conv_upsample_layers.append(
                ConvUpsample(
                    in_channels,
                    inner_channels,
                    num_layers=i if i > 0 else 1,
                    num_upsample=i if i > 0 else 0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                ))
        self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)

    def _set_things_to_void(self, gt_semantic_seg):
        """Merge thing classes to one class.

        In PanopticFPN, the background labels will be reset from ``0`` to
        ``self.num_stuff_classes - 1``, and the foreground labels will be
        merged to the ``self.num_stuff_classes``-th channel.
        """
        gt_semantic_seg = gt_semantic_seg.int()
        fg_mask = gt_semantic_seg < self.num_things_classes
        bg_mask = (gt_semantic_seg >= self.num_things_classes) * (
            gt_semantic_seg < self.num_things_classes + self.num_stuff_classes)

        new_gt_seg = torch.clone(gt_semantic_seg)
        new_gt_seg = torch.where(bg_mask,
                                 gt_semantic_seg - self.num_things_classes,
                                 new_gt_seg)
        new_gt_seg = torch.where(fg_mask,
                                 fg_mask.int() * self.num_stuff_classes,
                                 new_gt_seg)
        return new_gt_seg
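
    # Illustrative mapping (added comment, not part of the original file):
    # with the default COCO split of num_things_classes=80 and
    # num_stuff_classes=53, ``_set_things_to_void`` maps a thing label such
    # as 2 to 53, maps a stuff label such as 95 to 95 - 80 = 15, and leaves
    # labels of 133 or larger (e.g. an ignore value of 255) unchanged.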

    def loss(self, seg_preds, gt_semantic_seg):
        """The loss of PanopticFPN head.

        Thing classes will be merged to one class in PanopticFPN.
        """
        gt_semantic_seg = self._set_things_to_void(gt_semantic_seg)
        return super().loss(seg_preds, gt_semantic_seg)

    def init_weights(self):
        super().init_weights()
        nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
        self.conv_logits.bias.data.zero_()

    def forward(self, x):
        # The number of subnets must not exceed the number of input
        # feature maps.
        assert self.num_stages <= len(x)

        feats = []
        for i, layer in enumerate(self.conv_upsample_layers):
            f = layer(x[self.start_level + i])
            feats.append(f)

        feats = torch.sum(torch.stack(feats, dim=0), dim=0)
        seg_preds = self.conv_logits(feats)
        out = dict(seg_preds=seg_preds, feats=feats)
        return out
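
The snippet below is a minimal usage sketch, not part of the file above. It assumes a working MMDetection 2.x environment (so the HEADS registry and the CrossEntropyLoss config can be resolved) and that the class is importable from mmdet.models.seg_heads; the tensor shapes are made-up stand-ins for FPN outputs at strides 4, 8, 16 and 32.

import torch
from mmdet.models.seg_heads import PanopticFPNHead  # import path assumed

head = PanopticFPNHead(num_things_classes=80, num_stuff_classes=53)
head.init_weights()

# Dummy FPN features for a batch of two images at strides 4, 8, 16 and 32.
feats = [torch.rand(2, 256, 64 // s, 64 // s) for s in (1, 2, 4, 8)]

out = head(feats)
print(out['seg_preds'].shape)  # expected: torch.Size([2, 54, 64, 64])

# Thing labels (< 80) collapse to 53; stuff labels shift down by 80.
gt = torch.tensor([[0, 79, 80, 132]])
print(head._set_things_to_void(gt))  # expected: [[53, 53, 0, 52]]

All thing classes collapse into a single channel here because per-class instance predictions come from the instance branch of Panoptic FPN, not from this semantic head.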
