You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

browse_dataset.py 3.5 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os
  4. from collections import Sequence
  5. from pathlib import Path
  6. import mmcv
  7. from mmcv import Config, DictAction
  8. from mmdet.core.utils import mask2ndarray
  9. from mmdet.core.visualization import imshow_det_bboxes
  10. from mmdet.datasets.builder import build_dataset
  11. def parse_args():
  12. parser = argparse.ArgumentParser(description='Browse a dataset')
  13. parser.add_argument('config', help='train config file path')
  14. parser.add_argument(
  15. '--skip-type',
  16. type=str,
  17. nargs='+',
  18. default=['DefaultFormatBundle', 'Normalize', 'Collect'],
  19. help='skip some useless pipeline')
  20. parser.add_argument(
  21. '--output-dir',
  22. default=None,
  23. type=str,
  24. help='If there is no display interface, you can save it')
  25. parser.add_argument('--not-show', default=False, action='store_true')
  26. parser.add_argument(
  27. '--show-interval',
  28. type=float,
  29. default=2,
  30. help='the interval of show (s)')
  31. parser.add_argument(
  32. '--cfg-options',
  33. nargs='+',
  34. action=DictAction,
  35. help='override some settings in the used config, the key-value pair '
  36. 'in xxx=yyy format will be merged into config file. If the value to '
  37. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  38. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  39. 'Note that the quotation marks are necessary and that no white space '
  40. 'is allowed.')
  41. args = parser.parse_args()
  42. return args
  43. def retrieve_data_cfg(config_path, skip_type, cfg_options):
  44. def skip_pipeline_steps(config):
  45. config['pipeline'] = [
  46. x for x in config.pipeline if x['type'] not in skip_type
  47. ]
  48. cfg = Config.fromfile(config_path)
  49. if cfg_options is not None:
  50. cfg.merge_from_dict(cfg_options)
  51. # import modules from string list.
  52. if cfg.get('custom_imports', None):
  53. from mmcv.utils import import_modules_from_strings
  54. import_modules_from_strings(**cfg['custom_imports'])
  55. train_data_cfg = cfg.data.train
  56. while 'dataset' in train_data_cfg and train_data_cfg[
  57. 'type'] != 'MultiImageMixDataset':
  58. train_data_cfg = train_data_cfg['dataset']
  59. if isinstance(train_data_cfg, Sequence):
  60. [skip_pipeline_steps(c) for c in train_data_cfg]
  61. else:
  62. skip_pipeline_steps(train_data_cfg)
  63. return cfg
  64. def main():
  65. args = parse_args()
  66. cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options)
  67. dataset = build_dataset(cfg.data.train)
  68. progress_bar = mmcv.ProgressBar(len(dataset))
  69. for item in dataset:
  70. filename = os.path.join(args.output_dir,
  71. Path(item['filename']).name
  72. ) if args.output_dir is not None else None
  73. gt_masks = item.get('gt_masks', None)
  74. if gt_masks is not None:
  75. gt_masks = mask2ndarray(gt_masks)
  76. imshow_det_bboxes(
  77. item['img'],
  78. item['gt_bboxes'],
  79. item['gt_labels'],
  80. gt_masks,
  81. class_names=dataset.CLASSES,
  82. show=not args.not_show,
  83. wait_time=args.show_interval,
  84. out_file=filename,
  85. bbox_color=(255, 102, 61),
  86. text_color=(255, 102, 61))
  87. progress_bar.update()
  88. if __name__ == '__main__':
  89. main()

No Description

Contributors (1)