You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_unrepeated_sampler.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from itertools import chain
  2. import pytest
  3. from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
  4. class DatasetWithVaryLength:
  5. def __init__(self, num_of_data=100):
  6. self.data = list(range(num_of_data))
  7. def __getitem__(self, item):
  8. return self.data[item]
  9. def __len__(self):
  10. return len(self.data)
  11. class TestUnrepeatedSampler:
  12. @pytest.mark.parametrize('shuffle', [True, False])
  13. def test_single(self, shuffle):
  14. num_of_data = 100
  15. data = DatasetWithVaryLength(num_of_data)
  16. sampler = UnrepeatedRandomSampler(data, shuffle)
  17. indexes = set(sampler)
  18. assert indexes==set(range(num_of_data))
  19. @pytest.mark.parametrize('num_replicas', [2, 3])
  20. @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
  21. @pytest.mark.parametrize('shuffle', [False, True])
  22. def test_multi(self, num_replica, num_of_data, shuffle):
  23. data = DatasetWithVaryLength(num_of_data=num_of_data)
  24. samplers = []
  25. for i in range(num_replica):
  26. sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle)
  27. sampler.set_distributed(num_replica, rank=i)
  28. samplers.append(sampler)
  29. indexes = list(chain(*samplers))
  30. assert len(indexes) == num_of_data
  31. indexes = set(indexes)
  32. assert indexes==set(range(num_of_data))
  33. class TestUnrepeatedSortedSampler:
  34. def test_single(self):
  35. num_of_data = 100
  36. data = DatasetWithVaryLength(num_of_data)
  37. sampler = UnrepeatedSortedSampler(data, length=data.data)
  38. indexes = list(sampler)
  39. assert indexes==list(range(num_of_data-1, -1, -1))
  40. @pytest.mark.parametrize('num_replicas', [2, 3])
  41. @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
  42. def test_multi(self, num_replica, num_of_data):
  43. data = DatasetWithVaryLength(num_of_data=num_of_data)
  44. samplers = []
  45. for i in range(num_replica):
  46. sampler = UnrepeatedSortedSampler(dataset=data, length=data.data)
  47. sampler.set_distributed(num_replica, rank=i)
  48. samplers.append(sampler)
  49. # 保证顺序是没乱的
  50. for sampler in samplers:
  51. prev_index = float('inf')
  52. for index in sampler:
  53. assert index <= prev_index
  54. prev_index = index
  55. indexes = list(chain(*samplers))
  56. assert len(indexes) == num_of_data # 不同卡之间没有交叉
  57. indexes = set(indexes)
  58. assert indexes==set(range(num_of_data))
  59. class TestUnrepeatedSequentialSampler:
  60. def test_single(self):
  61. num_of_data = 100
  62. data = DatasetWithVaryLength(num_of_data)
  63. sampler = UnrepeatedSequentialSampler(data, length=data.data)
  64. indexes = list(sampler)
  65. assert indexes==list(range(num_of_data))
  66. @pytest.mark.parametrize('num_replicas', [2, 3])
  67. @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
  68. def test_multi(self, num_replica, num_of_data):
  69. data = DatasetWithVaryLength(num_of_data=num_of_data)
  70. samplers = []
  71. for i in range(num_replica):
  72. sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data)
  73. sampler.set_distributed(num_replica, rank=i)
  74. samplers.append(sampler)
  75. # 保证顺序是没乱的
  76. for sampler in samplers:
  77. prev_index = float('-inf')
  78. for index in sampler:
  79. assert index>=prev_index
  80. prev_index = index
  81. indexes = list(chain(*samplers))
  82. assert len(indexes) == num_of_data
  83. indexes = set(indexes)
  84. assert indexes == set(range(num_of_data))