You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

base_mask_head.py 4.5 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. from mmcv.runner import BaseModule
  4. class BaseMaskHead(BaseModule, metaclass=ABCMeta):
  5. """Base class for mask heads used in One-Stage Instance Segmentation."""
  6. def __init__(self, init_cfg):
  7. super(BaseMaskHead, self).__init__(init_cfg)
  8. @abstractmethod
  9. def loss(self, **kwargs):
  10. pass
  11. @abstractmethod
  12. def get_results(self, **kwargs):
  13. """Get precessed :obj:`InstanceData` of multiple images."""
  14. pass
  15. def forward_train(self,
  16. x,
  17. gt_labels,
  18. gt_masks,
  19. img_metas,
  20. gt_bboxes=None,
  21. gt_bboxes_ignore=None,
  22. positive_infos=None,
  23. **kwargs):
  24. """
  25. Args:
  26. x (list[Tensor] | tuple[Tensor]): Features from FPN.
  27. Each has a shape (B, C, H, W).
  28. gt_labels (list[Tensor]): Ground truth labels of all images.
  29. each has a shape (num_gts,).
  30. gt_masks (list[Tensor]) : Masks for each bbox, has a shape
  31. (num_gts, h , w).
  32. img_metas (list[dict]): Meta information of each image, e.g.,
  33. image size, scaling factor, etc.
  34. gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
  35. each item has a shape (num_gts, 4).
  36. gt_bboxes_ignore (list[Tensor], None): Ground truth bboxes to be
  37. ignored, each item has a shape (num_ignored_gts, 4).
  38. positive_infos (list[:obj:`InstanceData`], optional): Information
  39. of positive samples. Used when the label assignment is
  40. done outside the MaskHead, e.g., in BboxHead in
  41. YOLACT or CondInst, etc. When the label assignment is done in
  42. MaskHead, it would be None, like SOLO. All values
  43. in it should have shape (num_positive_samples, *).
  44. Returns:
  45. dict[str, Tensor]: A dictionary of loss components.
  46. """
  47. if positive_infos is None:
  48. outs = self(x)
  49. else:
  50. outs = self(x, positive_infos)
  51. assert isinstance(outs, tuple), 'Forward results should be a tuple, ' \
  52. 'even if only one item is returned'
  53. loss = self.loss(
  54. *outs,
  55. gt_labels=gt_labels,
  56. gt_masks=gt_masks,
  57. img_metas=img_metas,
  58. gt_bboxes=gt_bboxes,
  59. gt_bboxes_ignore=gt_bboxes_ignore,
  60. positive_infos=positive_infos,
  61. **kwargs)
  62. return loss
  63. def simple_test(self,
  64. feats,
  65. img_metas,
  66. rescale=False,
  67. instances_list=None,
  68. **kwargs):
  69. """Test function without test-time augmentation.
  70. Args:
  71. feats (tuple[torch.Tensor]): Multi-level features from the
  72. upstream network, each is a 4D-tensor.
  73. img_metas (list[dict]): List of image information.
  74. rescale (bool, optional): Whether to rescale the results.
  75. Defaults to False.
  76. instances_list (list[obj:`InstanceData`], optional): Detection
  77. results of each image after the post process. Only exist
  78. if there is a `bbox_head`, like `YOLACT`, `CondInst`, etc.
  79. Returns:
  80. list[obj:`InstanceData`]: Instance segmentation \
  81. results of each image after the post process. \
  82. Each item usually contains following keys. \
  83. - scores (Tensor): Classification scores, has a shape
  84. (num_instance,)
  85. - labels (Tensor): Has a shape (num_instances,).
  86. - masks (Tensor): Processed mask results, has a
  87. shape (num_instances, h, w).
  88. """
  89. if instances_list is None:
  90. outs = self(feats)
  91. else:
  92. outs = self(feats, instances_list=instances_list)
  93. mask_inputs = outs + (img_metas, )
  94. results_list = self.get_results(
  95. *mask_inputs,
  96. rescale=rescale,
  97. instances_list=instances_list,
  98. **kwargs)
  99. return results_list
  100. def onnx_export(self, img, img_metas):
  101. raise NotImplementedError(f'{self.__class__.__name__} does '
  102. f'not support ONNX EXPORT')

No Description

Contributors (3)