You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

instaboost.py 4.5 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. from ..builder import PIPELINES
  4. @PIPELINES.register_module()
  5. class InstaBoost:
  6. r"""Data augmentation method in `InstaBoost: Boosting Instance
  7. Segmentation Via Probability Map Guided Copy-Pasting
  8. <https://arxiv.org/abs/1908.07801>`_.
  9. Refer to https://github.com/GothicAi/Instaboost for implementation details.
  10. Args:
  11. action_candidate (tuple): Action candidates. "normal", "horizontal", \
  12. "vertical", "skip" are supported. Default: ('normal', \
  13. 'horizontal', 'skip').
  14. action_prob (tuple): Corresponding action probabilities. Should be \
  15. the same length as action_candidate. Default: (1, 0, 0).
  16. scale (tuple): (min scale, max scale). Default: (0.8, 1.2).
  17. dx (int): The maximum x-axis shift will be (instance width) / dx.
  18. Default 15.
  19. dy (int): The maximum y-axis shift will be (instance height) / dy.
  20. Default 15.
  21. theta (tuple): (min rotation degree, max rotation degree). \
  22. Default: (-1, 1).
  23. color_prob (float): Probability of images for color augmentation.
  24. Default 0.5.
  25. heatmap_flag (bool): Whether to use heatmap guided. Default False.
  26. aug_ratio (float): Probability of applying this transformation. \
  27. Default 0.5.
  28. """
  29. def __init__(self,
  30. action_candidate=('normal', 'horizontal', 'skip'),
  31. action_prob=(1, 0, 0),
  32. scale=(0.8, 1.2),
  33. dx=15,
  34. dy=15,
  35. theta=(-1, 1),
  36. color_prob=0.5,
  37. hflag=False,
  38. aug_ratio=0.5):
  39. try:
  40. import instaboostfast as instaboost
  41. except ImportError:
  42. raise ImportError(
  43. 'Please run "pip install instaboostfast" '
  44. 'to install instaboostfast first for instaboost augmentation.')
  45. self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
  46. scale, dx, dy, theta,
  47. color_prob, hflag)
  48. self.aug_ratio = aug_ratio
  49. def _load_anns(self, results):
  50. labels = results['ann_info']['labels']
  51. masks = results['ann_info']['masks']
  52. bboxes = results['ann_info']['bboxes']
  53. n = len(labels)
  54. anns = []
  55. for i in range(n):
  56. label = labels[i]
  57. bbox = bboxes[i]
  58. mask = masks[i]
  59. x1, y1, x2, y2 = bbox
  60. # assert (x2 - x1) >= 1 and (y2 - y1) >= 1
  61. bbox = [x1, y1, x2 - x1, y2 - y1]
  62. anns.append({
  63. 'category_id': label,
  64. 'segmentation': mask,
  65. 'bbox': bbox
  66. })
  67. return anns
  68. def _parse_anns(self, results, anns, img):
  69. gt_bboxes = []
  70. gt_labels = []
  71. gt_masks_ann = []
  72. for ann in anns:
  73. x1, y1, w, h = ann['bbox']
  74. # TODO: more essential bug need to be fixed in instaboost
  75. if w <= 0 or h <= 0:
  76. continue
  77. bbox = [x1, y1, x1 + w, y1 + h]
  78. gt_bboxes.append(bbox)
  79. gt_labels.append(ann['category_id'])
  80. gt_masks_ann.append(ann['segmentation'])
  81. gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
  82. gt_labels = np.array(gt_labels, dtype=np.int64)
  83. results['ann_info']['labels'] = gt_labels
  84. results['ann_info']['bboxes'] = gt_bboxes
  85. results['ann_info']['masks'] = gt_masks_ann
  86. results['img'] = img
  87. return results
  88. def __call__(self, results):
  89. img = results['img']
  90. orig_type = img.dtype
  91. anns = self._load_anns(results)
  92. if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
  93. try:
  94. import instaboostfast as instaboost
  95. except ImportError:
  96. raise ImportError('Please run "pip install instaboostfast" '
  97. 'to install instaboostfast first.')
  98. anns, img = instaboost.get_new_data(
  99. anns, img.astype(np.uint8), self.cfg, background=None)
  100. results = self._parse_anns(results, anns, img.astype(orig_type))
  101. return results
  102. def __repr__(self):
  103. repr_str = self.__class__.__name__
  104. repr_str += f'(cfg={self.cfg}, aug_ratio={self.aug_ratio})'
  105. return repr_str

No Description

Contributors (3)