You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset_wrapper.py 5.4 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import bisect
  3. import math
  4. from collections import defaultdict
  5. from unittest.mock import MagicMock
  6. import numpy as np
  7. from mmdet.datasets import (ClassBalancedDataset, ConcatDataset, CustomDataset,
  8. MultiImageMixDataset, RepeatDataset)
  9. def test_dataset_wrapper():
  10. CustomDataset.load_annotations = MagicMock()
  11. CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
  12. dataset_a = CustomDataset(
  13. ann_file=MagicMock(), pipeline=[], test_mode=True, img_prefix='')
  14. len_a = 10
  15. cat_ids_list_a = [
  16. np.random.randint(0, 80, num).tolist()
  17. for num in np.random.randint(1, 20, len_a)
  18. ]
  19. dataset_a.data_infos = MagicMock()
  20. dataset_a.data_infos.__len__.return_value = len_a
  21. dataset_a.get_cat_ids = MagicMock(
  22. side_effect=lambda idx: cat_ids_list_a[idx])
  23. dataset_b = CustomDataset(
  24. ann_file=MagicMock(), pipeline=[], test_mode=True, img_prefix='')
  25. len_b = 20
  26. cat_ids_list_b = [
  27. np.random.randint(0, 80, num).tolist()
  28. for num in np.random.randint(1, 20, len_b)
  29. ]
  30. dataset_b.data_infos = MagicMock()
  31. dataset_b.data_infos.__len__.return_value = len_b
  32. dataset_b.get_cat_ids = MagicMock(
  33. side_effect=lambda idx: cat_ids_list_b[idx])
  34. concat_dataset = ConcatDataset([dataset_a, dataset_b])
  35. assert concat_dataset[5] == 5
  36. assert concat_dataset[25] == 15
  37. assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
  38. assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
  39. assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
  40. repeat_dataset = RepeatDataset(dataset_a, 10)
  41. assert repeat_dataset[5] == 5
  42. assert repeat_dataset[15] == 5
  43. assert repeat_dataset[27] == 7
  44. assert repeat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
  45. assert repeat_dataset.get_cat_ids(15) == cat_ids_list_a[5]
  46. assert repeat_dataset.get_cat_ids(27) == cat_ids_list_a[7]
  47. assert len(repeat_dataset) == 10 * len(dataset_a)
  48. category_freq = defaultdict(int)
  49. for cat_ids in cat_ids_list_a:
  50. cat_ids = set(cat_ids)
  51. for cat_id in cat_ids:
  52. category_freq[cat_id] += 1
  53. for k, v in category_freq.items():
  54. category_freq[k] = v / len(cat_ids_list_a)
  55. mean_freq = np.mean(list(category_freq.values()))
  56. repeat_thr = mean_freq
  57. category_repeat = {
  58. cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
  59. for cat_id, cat_freq in category_freq.items()
  60. }
  61. repeat_factors = []
  62. for cat_ids in cat_ids_list_a:
  63. cat_ids = set(cat_ids)
  64. repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
  65. repeat_factors.append(math.ceil(repeat_factor))
  66. repeat_factors_cumsum = np.cumsum(repeat_factors)
  67. repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
  68. assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
  69. for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
  70. assert repeat_factor_dataset[idx] == bisect.bisect_right(
  71. repeat_factors_cumsum, idx)
  72. img_scale = (60, 60)
  73. dynamic_scale = (80, 80)
  74. pipeline = [
  75. dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
  76. dict(
  77. type='RandomAffine',
  78. scaling_ratio_range=(0.1, 2),
  79. border=(-img_scale[0] // 2, -img_scale[1] // 2)),
  80. dict(
  81. type='MixUp',
  82. img_scale=img_scale,
  83. ratio_range=(0.8, 1.6),
  84. pad_val=114.0),
  85. dict(type='RandomFlip', flip_ratio=0.5),
  86. dict(type='Resize', keep_ratio=True),
  87. dict(type='Pad', pad_to_square=True, pad_val=114.0),
  88. ]
  89. CustomDataset.load_annotations = MagicMock()
  90. results = []
  91. for _ in range(2):
  92. height = np.random.randint(10, 30)
  93. weight = np.random.randint(10, 30)
  94. img = np.ones((height, weight, 3))
  95. gt_bbox = np.concatenate([
  96. np.random.randint(1, 5, (2, 2)),
  97. np.random.randint(1, 5, (2, 2)) + 5
  98. ],
  99. axis=1)
  100. gt_labels = np.random.randint(0, 80, 2)
  101. results.append(dict(gt_bboxes=gt_bbox, gt_labels=gt_labels, img=img))
  102. CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: results[idx])
  103. dataset_a = CustomDataset(
  104. ann_file=MagicMock(), pipeline=[], test_mode=True, img_prefix='')
  105. len_a = 2
  106. cat_ids_list_a = [
  107. np.random.randint(0, 80, num).tolist()
  108. for num in np.random.randint(1, 20, len_a)
  109. ]
  110. dataset_a.data_infos = MagicMock()
  111. dataset_a.data_infos.__len__.return_value = len_a
  112. dataset_a.get_cat_ids = MagicMock(
  113. side_effect=lambda idx: cat_ids_list_a[idx])
  114. multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline,
  115. dynamic_scale)
  116. for idx in range(len_a):
  117. results_ = multi_image_mix_dataset[idx]
  118. assert results_['img'].shape == (dynamic_scale[0], dynamic_scale[1], 3)
  119. # test skip_type_keys
  120. multi_image_mix_dataset = MultiImageMixDataset(
  121. dataset_a,
  122. pipeline,
  123. dynamic_scale,
  124. skip_type_keys=('MixUp', 'RandomFlip', 'Resize', 'Pad'))
  125. for idx in range(len_a):
  126. results_ = multi_image_mix_dataset[idx]
  127. assert results_['img'].shape == (img_scale[0], img_scale[1], 3)

No Description

Contributors (2)