You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_time_aug.py 4.5 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import mmcv
  4. from ..builder import PIPELINES
  5. from .compose import Compose
  6. @PIPELINES.register_module()
  7. class MultiScaleFlipAug:
  8. """Test-time augmentation with multiple scales and flipping.
  9. An example configuration is as followed:
  10. .. code-block::
  11. img_scale=[(1333, 400), (1333, 800)],
  12. flip=True,
  13. transforms=[
  14. dict(type='Resize', keep_ratio=True),
  15. dict(type='RandomFlip'),
  16. dict(type='Normalize', **img_norm_cfg),
  17. dict(type='Pad', size_divisor=32),
  18. dict(type='ImageToTensor', keys=['img']),
  19. dict(type='Collect', keys=['img']),
  20. ]
  21. After MultiScaleFLipAug with above configuration, the results are wrapped
  22. into lists of the same length as followed:
  23. .. code-block::
  24. dict(
  25. img=[...],
  26. img_shape=[...],
  27. scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
  28. flip=[False, True, False, True]
  29. ...
  30. )
  31. Args:
  32. transforms (list[dict]): Transforms to apply in each augmentation.
  33. img_scale (tuple | list[tuple] | None): Images scales for resizing.
  34. scale_factor (float | list[float] | None): Scale factors for resizing.
  35. flip (bool): Whether apply flip augmentation. Default: False.
  36. flip_direction (str | list[str]): Flip augmentation directions,
  37. options are "horizontal", "vertical" and "diagonal". If
  38. flip_direction is a list, multiple flip augmentations will be
  39. applied. It has no effect when flip == False. Default:
  40. "horizontal".
  41. """
  42. def __init__(self,
  43. transforms,
  44. img_scale=None,
  45. scale_factor=None,
  46. flip=False,
  47. flip_direction='horizontal'):
  48. self.transforms = Compose(transforms)
  49. assert (img_scale is None) ^ (scale_factor is None), (
  50. 'Must have but only one variable can be set')
  51. if img_scale is not None:
  52. self.img_scale = img_scale if isinstance(img_scale,
  53. list) else [img_scale]
  54. self.scale_key = 'scale'
  55. assert mmcv.is_list_of(self.img_scale, tuple)
  56. else:
  57. self.img_scale = scale_factor if isinstance(
  58. scale_factor, list) else [scale_factor]
  59. self.scale_key = 'scale_factor'
  60. self.flip = flip
  61. self.flip_direction = flip_direction if isinstance(
  62. flip_direction, list) else [flip_direction]
  63. assert mmcv.is_list_of(self.flip_direction, str)
  64. if not self.flip and self.flip_direction != ['horizontal']:
  65. warnings.warn(
  66. 'flip_direction has no effect when flip is set to False')
  67. if (self.flip
  68. and not any([t['type'] == 'RandomFlip' for t in transforms])):
  69. warnings.warn(
  70. 'flip has no effect when RandomFlip is not in transforms')
  71. def __call__(self, results):
  72. """Call function to apply test time augment transforms on results.
  73. Args:l
  74. results (dict): Result dict contains the data to transform.
  75. Returns:
  76. dict[str: list]: The augmented data, where each value is wrapped
  77. into a list.
  78. """
  79. aug_data = []
  80. flip_args = [(False, None)]
  81. if self.flip:
  82. flip_args += [(True, direction)
  83. for direction in self.flip_direction]
  84. for scale in self.img_scale:
  85. for flip, direction in flip_args:
  86. _results = results.copy()
  87. _results[self.scale_key] = scale
  88. _results['flip'] = flip
  89. _results['flip_direction'] = direction
  90. data = self.transforms(_results)
  91. aug_data.append(data)
  92. # list of dict to dict of list
  93. aug_data_dict = {key: [] for key in aug_data[0]}
  94. for data in aug_data:
  95. for key, val in data.items():
  96. aug_data_dict[key].append(val)
  97. return aug_data_dict
  98. def __repr__(self):
  99. repr_str = self.__class__.__name__
  100. repr_str += f'(transforms={self.transforms}, '
  101. repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
  102. repr_str += f'flip_direction={self.flip_direction})'
  103. return repr_str

No Description

Contributors (2)