You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

utils.py 6.5 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import warnings
  4. from mmcv.cnn import VGG
  5. from mmcv.runner.hooks import HOOKS, Hook
  6. from mmdet.datasets.builder import PIPELINES
  7. from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile
  8. from mmdet.models.dense_heads import GARPNHead, RPNHead
  9. from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
  10. def replace_ImageToTensor(pipelines):
  11. """Replace the ImageToTensor transform in a data pipeline to
  12. DefaultFormatBundle, which is normally useful in batch inference.
  13. Args:
  14. pipelines (list[dict]): Data pipeline configs.
  15. Returns:
  16. list: The new pipeline list with all ImageToTensor replaced by
  17. DefaultFormatBundle.
  18. Examples:
  19. >>> pipelines = [
  20. ... dict(type='LoadImageFromFile'),
  21. ... dict(
  22. ... type='MultiScaleFlipAug',
  23. ... img_scale=(1333, 800),
  24. ... flip=False,
  25. ... transforms=[
  26. ... dict(type='Resize', keep_ratio=True),
  27. ... dict(type='RandomFlip'),
  28. ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
  29. ... dict(type='Pad', size_divisor=32),
  30. ... dict(type='ImageToTensor', keys=['img']),
  31. ... dict(type='Collect', keys=['img']),
  32. ... ])
  33. ... ]
  34. >>> expected_pipelines = [
  35. ... dict(type='LoadImageFromFile'),
  36. ... dict(
  37. ... type='MultiScaleFlipAug',
  38. ... img_scale=(1333, 800),
  39. ... flip=False,
  40. ... transforms=[
  41. ... dict(type='Resize', keep_ratio=True),
  42. ... dict(type='RandomFlip'),
  43. ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
  44. ... dict(type='Pad', size_divisor=32),
  45. ... dict(type='DefaultFormatBundle'),
  46. ... dict(type='Collect', keys=['img']),
  47. ... ])
  48. ... ]
  49. >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
  50. """
  51. pipelines = copy.deepcopy(pipelines)
  52. for i, pipeline in enumerate(pipelines):
  53. if pipeline['type'] == 'MultiScaleFlipAug':
  54. assert 'transforms' in pipeline
  55. pipeline['transforms'] = replace_ImageToTensor(
  56. pipeline['transforms'])
  57. elif pipeline['type'] == 'ImageToTensor':
  58. warnings.warn(
  59. '"ImageToTensor" pipeline is replaced by '
  60. '"DefaultFormatBundle" for batch inference. It is '
  61. 'recommended to manually replace it in the test '
  62. 'data pipeline in your config file.', UserWarning)
  63. pipelines[i] = {'type': 'DefaultFormatBundle'}
  64. return pipelines
  65. def get_loading_pipeline(pipeline):
  66. """Only keep loading image and annotations related configuration.
  67. Args:
  68. pipeline (list[dict]): Data pipeline configs.
  69. Returns:
  70. list[dict]: The new pipeline list with only keep
  71. loading image and annotations related configuration.
  72. Examples:
  73. >>> pipelines = [
  74. ... dict(type='LoadImageFromFile'),
  75. ... dict(type='LoadAnnotations', with_bbox=True),
  76. ... dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
  77. ... dict(type='RandomFlip', flip_ratio=0.5),
  78. ... dict(type='Normalize', **img_norm_cfg),
  79. ... dict(type='Pad', size_divisor=32),
  80. ... dict(type='DefaultFormatBundle'),
  81. ... dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
  82. ... ]
  83. >>> expected_pipelines = [
  84. ... dict(type='LoadImageFromFile'),
  85. ... dict(type='LoadAnnotations', with_bbox=True)
  86. ... ]
  87. >>> assert expected_pipelines ==\
  88. ... get_loading_pipeline(pipelines)
  89. """
  90. loading_pipeline_cfg = []
  91. for cfg in pipeline:
  92. obj_cls = PIPELINES.get(cfg['type'])
  93. # TODO:use more elegant way to distinguish loading modules
  94. if obj_cls is not None and obj_cls in (LoadImageFromFile,
  95. LoadAnnotations):
  96. loading_pipeline_cfg.append(cfg)
  97. assert len(loading_pipeline_cfg) == 2, \
  98. 'The data pipeline in your config file must include ' \
  99. 'loading image and annotations related pipeline.'
  100. return loading_pipeline_cfg
  101. @HOOKS.register_module()
  102. class NumClassCheckHook(Hook):
  103. def _check_head(self, runner):
  104. """Check whether the `num_classes` in head matches the length of
  105. `CLASSES` in `dataset`.
  106. Args:
  107. runner (obj:`EpochBasedRunner`): Epoch based Runner.
  108. """
  109. model = runner.model
  110. dataset = runner.data_loader.dataset
  111. if dataset.CLASSES is None:
  112. runner.logger.warning(
  113. f'Please set `CLASSES` '
  114. f'in the {dataset.__class__.__name__} and'
  115. f'check if it is consistent with the `num_classes` '
  116. f'of head')
  117. else:
  118. assert type(dataset.CLASSES) is not str, \
  119. (f'`CLASSES` in {dataset.__class__.__name__}'
  120. f'should be a tuple of str.'
  121. f'Add comma if number of classes is 1 as '
  122. f'CLASSES = ({dataset.CLASSES},)')
  123. for name, module in model.named_modules():
  124. if hasattr(module, 'num_classes') and not isinstance(
  125. module, (RPNHead, VGG, FusedSemanticHead, GARPNHead)):
  126. assert module.num_classes == len(dataset.CLASSES), \
  127. (f'The `num_classes` ({module.num_classes}) in '
  128. f'{module.__class__.__name__} of '
  129. f'{model.__class__.__name__} does not matches '
  130. f'the length of `CLASSES` '
  131. f'{len(dataset.CLASSES)}) in '
  132. f'{dataset.__class__.__name__}')
  133. def before_train_epoch(self, runner):
  134. """Check whether the training dataset is compatible with head.
  135. Args:
  136. runner (obj:`EpochBasedRunner`): Epoch based Runner.
  137. """
  138. self._check_head(runner)
  139. def before_val_epoch(self, runner):
  140. """Check whether the dataset in val epoch is compatible with head.
  141. Args:
  142. runner (obj:`EpochBasedRunner`): Epoch based Runner.
  143. """
  144. self._check_head(runner)

No Description

Contributors (3)