You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_custom_dataset.py 4.8 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os
  3. import unittest
  4. from unittest.mock import MagicMock, patch
  5. import pytest
  6. from mmdet.datasets import DATASETS
  7. @patch('mmdet.datasets.CocoDataset.load_annotations', MagicMock())
  8. @patch('mmdet.datasets.CustomDataset.load_annotations', MagicMock())
  9. @patch('mmdet.datasets.XMLDataset.load_annotations', MagicMock())
  10. @patch('mmdet.datasets.CityscapesDataset.load_annotations', MagicMock())
  11. @patch('mmdet.datasets.CocoDataset._filter_imgs', MagicMock)
  12. @patch('mmdet.datasets.CustomDataset._filter_imgs', MagicMock)
  13. @patch('mmdet.datasets.XMLDataset._filter_imgs', MagicMock)
  14. @patch('mmdet.datasets.CityscapesDataset._filter_imgs', MagicMock)
  15. @pytest.mark.parametrize('dataset',
  16. ['CocoDataset', 'VOCDataset', 'CityscapesDataset'])
  17. def test_custom_classes_override_default(dataset):
  18. dataset_class = DATASETS.get(dataset)
  19. if dataset in ['CocoDataset', 'CityscapesDataset']:
  20. dataset_class.coco = MagicMock()
  21. dataset_class.cat_ids = MagicMock()
  22. original_classes = dataset_class.CLASSES
  23. # Test setting classes as a tuple
  24. custom_dataset = dataset_class(
  25. ann_file=MagicMock(),
  26. pipeline=[],
  27. classes=('bus', 'car'),
  28. test_mode=True,
  29. img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
  30. assert custom_dataset.CLASSES != original_classes
  31. assert custom_dataset.CLASSES == ('bus', 'car')
  32. print(custom_dataset)
  33. # Test setting classes as a list
  34. custom_dataset = dataset_class(
  35. ann_file=MagicMock(),
  36. pipeline=[],
  37. classes=['bus', 'car'],
  38. test_mode=True,
  39. img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
  40. assert custom_dataset.CLASSES != original_classes
  41. assert custom_dataset.CLASSES == ['bus', 'car']
  42. print(custom_dataset)
  43. # Test overriding not a subset
  44. custom_dataset = dataset_class(
  45. ann_file=MagicMock(),
  46. pipeline=[],
  47. classes=['foo'],
  48. test_mode=True,
  49. img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
  50. assert custom_dataset.CLASSES != original_classes
  51. assert custom_dataset.CLASSES == ['foo']
  52. print(custom_dataset)
  53. # Test default behavior
  54. custom_dataset = dataset_class(
  55. ann_file=MagicMock(),
  56. pipeline=[],
  57. classes=None,
  58. test_mode=True,
  59. img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
  60. assert custom_dataset.CLASSES == original_classes
  61. print(custom_dataset)
  62. # Test sending file path
  63. import tempfile
  64. tmp_file = tempfile.NamedTemporaryFile()
  65. with open(tmp_file.name, 'w') as f:
  66. f.write('bus\ncar\n')
  67. custom_dataset = dataset_class(
  68. ann_file=MagicMock(),
  69. pipeline=[],
  70. classes=tmp_file.name,
  71. test_mode=True,
  72. img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
  73. tmp_file.close()
  74. assert custom_dataset.CLASSES != original_classes
  75. assert custom_dataset.CLASSES == ['bus', 'car']
  76. print(custom_dataset)
  77. class CustomDatasetTests(unittest.TestCase):
  78. def setUp(self):
  79. super().setUp()
  80. self.data_dir = os.path.join(
  81. os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
  82. 'data')
  83. self.dataset_class = DATASETS.get('XMLDataset')
  84. def test_data_infos__default_db_directories(self):
  85. """Test correct data read having a Pacal-VOC directory structure."""
  86. test_dataset_root = os.path.join(self.data_dir, 'VOCdevkit', 'VOC2007')
  87. custom_ds = self.dataset_class(
  88. data_root=test_dataset_root,
  89. ann_file=os.path.join(test_dataset_root, 'ImageSets', 'Main',
  90. 'trainval.txt'),
  91. pipeline=[],
  92. classes=('person', 'dog'),
  93. test_mode=True)
  94. self.assertListEqual([{
  95. 'id': '000001',
  96. 'filename': 'JPEGImages/000001.jpg',
  97. 'width': 353,
  98. 'height': 500
  99. }], custom_ds.data_infos)
  100. def test_data_infos__overridden_db_subdirectories(self):
  101. """Test correct data read having a customized directory structure."""
  102. test_dataset_root = os.path.join(self.data_dir, 'custom_dataset')
  103. custom_ds = self.dataset_class(
  104. data_root=test_dataset_root,
  105. ann_file=os.path.join(test_dataset_root, 'trainval.txt'),
  106. pipeline=[],
  107. classes=('person', 'dog'),
  108. test_mode=True,
  109. img_prefix='',
  110. img_subdir='images',
  111. ann_subdir='images')
  112. self.assertListEqual([{
  113. 'id': '000001',
  114. 'filename': 'images/000001.jpg',
  115. 'width': 353,
  116. 'height': 500
  117. }], custom_ds.data_infos)

No Description

Contributors (2)