# Tests for fastNLP reproducible samplers: RandomSampler, SortedSampler, SequentialSampler.
  1. import numpy as np
  2. import pytest
  3. from functools import partial
  4. from itertools import chain
  5. from copy import deepcopy
  6. from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler
  7. from tests.helpers.datasets.torch_data import TorchNormalDataset
class TestRandomSamplerYh:
    """Tests for :class:`RandomSampler`: construction, distributed sharding,
    and resuming iteration from a saved ``state_dict``."""

    def test_init(self):
        # Check that the sampler can be constructed and fully iterated.
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset)
        for i in sampler:
            pass

    def test_during_iter(self):
        # set_distributed() must be rejected while an iteration is in flight,
        # but allowed again once the (partial) iteration has been abandoned.
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset)
        for i in sampler:
            with pytest.raises(AssertionError):
                sampler.set_distributed(1, 0)
            break
        # should not raise
        for i in sampler:
            pass
        sampler.set_distributed(1, 0)

    def test_set_distributed(self):
        # Without shuffling, rank r of 2 replicas should see exactly the
        # indices i with i % 2 == r; pad=True rounds the length up so every
        # rank yields the same number of samples.
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset, shuffle=False)
        sampler.set_distributed(num_replicas=2, rank=0, pad=False)
        assert len(sampler)==50
        count = 0
        for i in sampler:
            assert i%2==0
            count += 1
        assert count == 50
        sampler.set_distributed(num_replicas=2, rank=1, pad=False)
        assert len(sampler)==50
        count = 0
        for i in sampler:
            assert i%2==1
            count += 1
        assert count==50
        # Odd-sized dataset with padding: both ranks report 51 samples.
        dataset = TorchNormalDataset(num_of_data=101)
        sampler = RandomSampler(dataset, shuffle=False)
        sampler.set_distributed(num_replicas=2, rank=0, pad=True)
        assert len(sampler)==51
        count = 0
        for i in sampler:
            assert i%2==0
            count += 1
        assert count == 51
        sampler.set_distributed(num_replicas=2, rank=1, pad=True)
        assert len(sampler) == 51
        count = 0
        for i in sampler:
            # Index 0 is the padded (repeated) sample on rank 1; every other
            # index must still be odd.
            if i!=0:
                assert i%2==1
            count += 1
        assert count == 51

    def test_state_dict_check_length(self):
        # Loading a state_dict saved from a dataset of a different length
        # must fail; same length must succeed.
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset, shuffle=False)
        states = sampler.state_dict()
        new_ds = TorchNormalDataset(num_of_data=10)
        with pytest.raises(AssertionError):
            new_sampler = RandomSampler(new_ds)
            new_sampler.load_state_dict(states)
        new_ds = TorchNormalDataset(num_of_data=100)
        new_sampler = RandomSampler(new_ds)
        new_sampler.load_state_dict(states)

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('pre_shuffle', [True, False])
    @pytest.mark.parametrize('post_shuffle', [True, False])
    @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
    def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
        # Resume after consuming some samples; the shuffle flag at load time
        # may differ from the one at save time.
        num_samples = 100
        dataset = TorchNormalDataset(num_of_data=num_samples)
        sampler = RandomSampler(dataset, shuffle=pre_shuffle)
        sampler.set_epoch(0)
        already_numbers = set()
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples
        states = sampler.state_dict()
        new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        # The resumed sampler must never repeat an already-consumed index.
        for i in new_sampler:
            assert i not in already_numbers
        # Switching to multi-card after the resume should work as well.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
            new_sampler.load_state_dict(states)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            new_sampler.set_epoch(0)
            count = 0
            seen = 0
            seen_in_other_rank = 0
            for i in new_sampler:
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
                count += 1
            # Parses as: (seen <= 1) if pad else (seen == 0) — padding may
            # repeat at most one already-consumed index.
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank<=1  # padding may duplicate one index across ranks

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('pre_shuffle', [True, False])
    @pytest.mark.parametrize('post_shuffle', [True, False])
    @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
    def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
        # Switching from multi-card back to single-card, or to a different
        # number of cards, after a partial epoch on 2 replicas.
        num_samples = 100
        dataset = TorchNormalDataset(num_of_data=num_samples)
        # Load with a shuffle setting that may differ from the saved one.
        # lst = [30]
        already_numbers = set()
        sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
        sampler.set_distributed(num_replicas=2, rank=0)
        sampler.set_epoch(0)
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        # Same seed on rank 1 consumes the complementary shard.
        sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=2, rank=1)
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples*2
        states = sampler.state_dict()
        new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            assert i not in already_numbers
        # Switching to 3 cards should work as well.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
            new_sampler.load_state_dict(states)
            new_sampler.set_epoch(0)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            count = 0
            seen = 0
            seen_in_other_rank = 0
            for i in new_sampler:
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
                count += 1
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank<=1  # padding may duplicate one index across ranks

    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
    @pytest.mark.parametrize('num_replicas', [1, 2, 3])
    def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas):
        # Even when the sampler over-produces (because of padding), resuming
        # from the per-step `num_consumed_samples_array` must still recover
        # every index of the dataset.
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        samplers = []
        for i in range(num_replicas):
            sampler = RandomSampler(dataset, shuffle=shuffle)
            sampler.set_epoch(0)
            sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
            samplers.append(sampler)
        count = 0
        already_seen_sets = [set()]
        already_seen_set = set()
        # Draw up to 4 lock-step "batches" (one index per replica) and record
        # the cumulative set of seen indices after each step.
        for idxes in zip(*samplers):
            already_seen_set.update(idxes)
            already_seen_sets.append(deepcopy(already_seen_set))
            count += 1
            if count > 3:
                break
        states = samplers[0].state_dict()
        for i in range(len(already_seen_sets)):
            if states['num_consumed_samples_array'] is not None:
                states['num_consumed_samples'] = states['num_consumed_samples_array'][i]
            sampler = RandomSampler(dataset, shuffle=shuffle)
            # NOTE(review): `states` is never loaded into this fresh sampler
            # (no load_state_dict call), so it iterates the full dataset and
            # the assert below passes trivially — confirm whether a
            # `sampler.load_state_dict(states)` is missing here.
            already_seen_set = deepcopy(already_seen_sets[i])
            for batch in sampler:
                already_seen_set.add(batch)
            assert len(already_seen_set) == len(dataset)
        # Save again after a resume and resume once more.
        sampler = RandomSampler(dataset, shuffle=shuffle)
        sampler.set_epoch(0)
        if states['num_consumed_samples_array'] is not None:
            states['num_consumed_samples'] = states['num_consumed_samples_array'][2]
        if len(already_seen_sets)<3:
            return
        already_seen_set = already_seen_sets[2]
        count = 0
        for idx in sampler:
            already_seen_set.add(idx)
            count += 1
            if count > 6:
                break
        states = sampler.state_dict()
        if states['num_consumed_samples_array'] is not None:
            states['num_consumed_samples'] = states['num_consumed_samples_array'][count]
        sampler = RandomSampler(dataset, shuffle=shuffle)
        sampler.load_state_dict(states)
        sampler.set_epoch(0)
        for idx in sampler:
            already_seen_set.add(idx)
        assert len(already_seen_set)==len(dataset)
class TestRandomSampler:
    """Iterator-reset and checkpoint-resume tests for :class:`RandomSampler`,
    first single-card, then multi-card."""

    # Single-card tests.
    def test_seed_work_when_shuffle_is_true(self):
        data_length = 100
        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        for shuffle in [True, False]:
            iterable = RandomSampler(dataset=torch_normal_data, shuffle=shuffle)
            # Iterate part of the data without exhausting the iterator.
            iterable.set_epoch(1)
            iterator = iter(iterable)
            pre_data = []
            forward_steps = 30
            for _ in range(forward_steps):
                pre_data.append(next(iterator))
            # Re-creating the iterator must fully reset the state.
            iterator = iter(iterable)
            res = []
            for _ in range(forward_steps):
                res.append(next(iterator))
            assert pre_data == res

    # Checkpoint-resume test.
    # With shuffle, the next epoch's order must differ from the previous one;
    # and after a resume, the following epoch must be reproduced exactly.
    def test_2(self):
        data_length = 100
        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        random_sampler_1 = RandomSampler(dataset=torch_normal_data, shuffle=True)
        iterator = iter(random_sampler_1)
        # First epoch.
        random_sampler_1.set_epoch(0)
        first_epoch = []
        forward_steps = 30
        for _ in range(forward_steps):
            first_epoch.append(next(iterator))
        # Save the resume state up front.
        state = random_sampler_1.state_dict()
        # Record the remainder of the first epoch to validate the resume.
        first_left_data = []
        while True:
            try:
                first_left_data.append(next(iterator))
            except StopIteration:
                break
        # Second epoch.
        random_sampler_1.set_epoch(1)
        iterator = iter(random_sampler_1)
        second_epoch = []
        for _ in range(forward_steps):
            second_epoch.append(next(iterator))
        assert first_epoch != second_epoch
        # Reload the first-epoch state and verify the resume is correct.
        random_sampler_2 = RandomSampler(dataset=torch_normal_data, shuffle=True)
        random_sampler_2.load_state_dict(state)
        random_sampler_2.set_epoch(0)
        iterator = iter(random_sampler_2)
        re_first_epoch = []
        while True:
            try:
                re_first_epoch.append(next(iterator))
            except StopIteration:
                break
        assert re_first_epoch == first_left_data
        # The second epoch must also match the original second epoch exactly.
        random_sampler_2.set_epoch(1)
        iterator = iter(random_sampler_2)
        re_second_epoch = []
        for _ in range(forward_steps):
            re_second_epoch.append(next(iterator))
        assert re_second_epoch == second_epoch

    # Multi-card tests.
    # If a sampler is not yet exhausted and iter(sampler) is called again,
    # a brand-new, fully reset iteration should start.
    def test_3(self):
        data_length = 100
        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=False)
        random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
        iterable_items = [random_sampler_1, random_sampler_2]
        world_size = 3
        for pad in {True, False}:
            for iterable in iterable_items:
                for rank in range(world_size):
                    each_rank_iterable = iterable()
                    each_rank_iterable.set_epoch(0)
                    each_rank_iterable.set_distributed(num_replicas=world_size, rank=rank, pad=pad)
                    # Iterate part of the data without exhausting the iterator.
                    iterator = iter(each_rank_iterable)
                    pre_data = []
                    forward_steps = 10
                    for _ in range(forward_steps):
                        pre_data.append(next(iterator))
                    # Re-creating the iterator must fully reset the state.
                    iterator = iter(each_rank_iterable)
                    res = []
                    for _ in range(forward_steps):
                        res.append(next(iterator))
                    assert res == pre_data

    # Multi-card checkpoint-resume test.
    # With shuffle, the next epoch must differ from the previous one; and a
    # resumed run must reproduce the remainder of the epoch and the epoch after.
    def test_4(self):
        data_length = 100
        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
        world_size_1 = 2
        forward_steps = 10
        for pad in {True, False}:
            all_rank_state = {}
            all_rank_first_left_data = {}
            all_rank_second_epoch = {}
            for rank in range(world_size_1):
                each_rank_iterable = random_sampler_1()
                each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad)
                iterator = iter(each_rank_iterable)
                # First epoch.
                each_rank_iterable.set_epoch(0)
                first_epoch = []
                for _ in range(forward_steps):
                    first_epoch.append(next(iterator))
                # Save the resume state up front.
                all_rank_state[rank] = each_rank_iterable.state_dict()
                # Record the remainder of the first epoch to validate the resume.
                first_left_data = []
                while True:
                    try:
                        first_left_data.append(next(iterator))
                    except StopIteration:
                        break
                all_rank_first_left_data[rank] = first_left_data
                # Second epoch.
                each_rank_iterable.set_epoch(1)
                iterator = iter(each_rank_iterable)
                second_epoch = []
                for _ in range(forward_steps):
                    second_epoch.append(next(iterator))
                all_rank_second_epoch[rank] = second_epoch
                assert first_epoch != second_epoch
            # Reload the first-epoch state and verify the resume is correct.
            random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
            for rank in range(world_size_1):
                each_rank_iterable = random_sampler_2()
                each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad)
                each_rank_iterable.load_state_dict(all_rank_state[rank])
                each_rank_iterable.set_epoch(0)
                iterator = iter(each_rank_iterable)
                re_first_epoch = []
                while True:
                    try:
                        re_first_epoch.append(next(iterator))
                    except StopIteration:
                        break
                assert re_first_epoch == all_rank_first_left_data[rank]
                # The second epoch must also match the original second epoch exactly.
                each_rank_iterable.set_epoch(1)
                iterator = iter(each_rank_iterable)
                re_second_epoch = []
                for _ in range(forward_steps):
                    re_second_epoch.append(next(iterator))
                assert re_second_epoch == all_rank_second_epoch[rank]

    # TODO: test checkpoint-resume when the ddp world_size changes.
    def test_5(self):
        ...
  375. class DatasetWithVaryLength:
  376. def __init__(self, num_of_data=100, reverse=False):
  377. self.data = np.arange(num_of_data)
  378. if reverse:
  379. self.data = self.data[::-1]
  380. def __getitem__(self, item):
  381. return self.data[item]
  382. def __len__(self):
  383. return len(self.data)
  384. class TestSortedSampler:
  385. def test_single(self):
  386. num_of_data = 100
  387. data = DatasetWithVaryLength(num_of_data)
  388. sampler = SortedSampler(data, length=data.data)
  389. indexes = list(sampler)
  390. assert indexes==list(range(num_of_data-1, -1, -1))
  391. @pytest.mark.parametrize('pad', [True, False])
  392. @pytest.mark.parametrize('num_replicas', [2, 3])
  393. @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
  394. def test_multi(self, pad, num_replica, num_of_data):
  395. data = DatasetWithVaryLength(num_of_data=num_of_data)
  396. samplers = []
  397. for i in range(num_replica):
  398. sampler = SortedSampler(dataset=data, length=data.data)
  399. sampler.set_distributed(num_replica, rank=i, pad=pad)
  400. samplers.append(sampler)
  401. # 保证顺序是没乱的
  402. already_seen_index = set()
  403. for sampler in samplers:
  404. larger_count = 0 # 这里为 0 就可以,因为最后补充的index一定是比较大的数。
  405. prev_index = float('inf')
  406. cur_set = set()
  407. seen_in_other_rank = 0
  408. for index in sampler:
  409. seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉
  410. cur_set.add(index)
  411. larger_count += int(index <= prev_index)
  412. prev_index = index
  413. assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序
  414. assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0
  415. already_seen_index.update(cur_set)
  416. indexes = list(chain(*samplers))
  417. indexes = set(indexes)
  418. if pad:
  419. assert indexes == set(range(num_of_data))
  420. else:
  421. assert len(indexes) <= num_of_data
  422. @pytest.mark.parametrize('pad', [True, False])
  423. @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
  424. def test_state_dict(self, pad, num_consumed_samples):
  425. num_samples = 100
  426. dataset = DatasetWithVaryLength(num_of_data=num_samples)
  427. # 测试使用 前后shuffle不一致的load操作
  428. sampler = SortedSampler(dataset, length=dataset.data)
  429. sampler.set_epoch(0)
  430. already_numbers = set()
  431. if num_consumed_samples>0:
  432. for i, j in enumerate(sampler, start=1):
  433. if already_numbers:
  434. assert j<max(already_numbers)
  435. already_numbers.add(j)
  436. if i == num_consumed_samples:
  437. break
  438. assert len(already_numbers) == num_consumed_samples
  439. states = sampler.state_dict()
  440. new_sampler = SortedSampler(dataset, length=dataset.data)
  441. new_sampler.load_state_dict(states)
  442. new_sampler.set_epoch(0)
  443. for i in new_sampler:
  444. if already_numbers:
  445. assert i < max(already_numbers)
  446. assert i not in already_numbers
  447. # 测试切换成多卡也没有问题
  448. other_rank_number = set()
  449. for rank in range(3):
  450. new_sampler = SortedSampler(dataset, length=dataset.data)
  451. new_sampler.load_state_dict(states)
  452. new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
  453. new_sampler.set_epoch(0)
  454. count = 0
  455. seen = 0
  456. seen_in_other_rank = 0
  457. smaller = 0
  458. for i in new_sampler:
  459. if already_numbers:
  460. smaller += int(i >= max(already_numbers))
  461. seen_in_other_rank += int(i in other_rank_number)
  462. other_rank_number.add(i)
  463. seen += int(i in already_numbers)
  464. count += 1
  465. assert seen <= 1 if pad else seen == 0
  466. assert seen_in_other_rank<=1 # 因为pad可能重复
  467. assert smaller<=1 if pad else smaller==0
  468. @pytest.mark.parametrize('pad', [True, False])
  469. @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
  470. def test_state_dict_2(self, pad, num_consumed_samples):
  471. # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
  472. num_samples = 100
  473. dataset = DatasetWithVaryLength(num_of_data=num_samples)
  474. # 测试使用 前后shuffle不一致的load操作
  475. # lst = [30]
  476. already_numbers = set()
  477. sampler = SortedSampler(dataset, length=dataset.data)
  478. sampler.set_distributed(num_replicas=2, rank=0)
  479. sampler.set_epoch(0)
  480. if num_consumed_samples>0:
  481. for i, j in enumerate(sampler, start=1):
  482. if already_numbers:
  483. assert j<=max(already_numbers)
  484. already_numbers.add(j)
  485. if i == num_consumed_samples:
  486. break
  487. sampler = SortedSampler(dataset, length=dataset.data)
  488. sampler.set_epoch(0)
  489. sampler.set_distributed(num_replicas=2, rank=1)
  490. if num_consumed_samples>0:
  491. for i, j in enumerate(sampler, start=1):
  492. already_numbers.add(j)
  493. if i == num_consumed_samples:
  494. break
  495. assert len(already_numbers) == num_consumed_samples*2
  496. states = sampler.state_dict()
  497. new_sampler = SortedSampler(dataset, length=dataset.data)
  498. new_sampler.load_state_dict(states)
  499. new_sampler.set_epoch(0)
  500. for i in new_sampler:
  501. if already_numbers:
  502. assert i < max(already_numbers)
  503. assert i not in already_numbers
  504. # 测试切换成多卡也没有问题
  505. other_rank_number = set()
  506. for rank in range(3):
  507. new_sampler = SortedSampler(dataset, length=dataset.data)
  508. new_sampler.load_state_dict(states)
  509. new_sampler.set_epoch(0)
  510. new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
  511. count = 0
  512. seen = 0
  513. seen_in_other_rank = 0
  514. smaller = 0
  515. for i in new_sampler:
  516. if already_numbers:
  517. smaller += int(i>=max(already_numbers))
  518. seen_in_other_rank += int(i in other_rank_number)
  519. other_rank_number.add(i)
  520. seen += int(i in already_numbers)
  521. count += 1
  522. assert seen <= 1 if pad else seen == 0
  523. assert seen_in_other_rank<=1 # 因为pad可能重复
  524. assert smaller <= 1 if pad else smaller == 0
  525. class TestSequentialSampler:
  526. def test_single(self):
  527. num_of_data = 100
  528. data = DatasetWithVaryLength(num_of_data)
  529. sampler = SequentialSampler(data)
  530. indexes = list(sampler)
  531. assert indexes==list(range(num_of_data))
  532. @pytest.mark.parametrize('pad', [True, False])
  533. @pytest.mark.parametrize('num_replicas', [2, 3])
  534. @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
  535. def test_multi(self, pad, num_replica, num_of_data):
  536. data = DatasetWithVaryLength(num_of_data=num_of_data)
  537. samplers = []
  538. for i in range(num_replica):
  539. sampler = SequentialSampler(dataset=data)
  540. sampler.set_distributed(num_replica, rank=i, pad=pad)
  541. samplers.append(sampler)
  542. # 保证顺序是没乱的
  543. already_seen_index = set()
  544. for idx, sampler in enumerate(samplers):
  545. larger_count = 1
  546. prev_index = float('inf')
  547. cur_set = set()
  548. seen_in_other_rank = 0
  549. for index in sampler:
  550. seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉
  551. cur_set.add(index)
  552. larger_count += int(index >= prev_index)
  553. prev_index = index
  554. assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序
  555. assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0
  556. already_seen_index.update(cur_set)
  557. indexes = list(chain(*samplers))
  558. indexes = set(indexes)
  559. if pad:
  560. assert indexes == set(range(num_of_data))
  561. else:
  562. assert len(indexes) <= num_of_data
  563. @pytest.mark.parametrize('pad', [True, False])
  564. @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
  565. def test_state_dict(self, pad, num_consumed_samples):
  566. num_samples = 100
  567. dataset = DatasetWithVaryLength(num_of_data=num_samples)
  568. # 测试使用 前后shuffle不一致的load操作
  569. sampler = SequentialSampler(dataset=dataset)
  570. sampler.set_epoch(0)
  571. already_numbers = set()
  572. if num_consumed_samples>0:
  573. for i, j in enumerate(sampler, start=1):
  574. if already_numbers:
  575. assert j>max(already_numbers)
  576. already_numbers.add(j)
  577. if i == num_consumed_samples:
  578. break
  579. assert len(already_numbers) == num_consumed_samples
  580. states = sampler.state_dict()
  581. new_sampler = SequentialSampler(dataset=dataset)
  582. new_sampler.load_state_dict(states)
  583. new_sampler.set_epoch(0)
  584. for i in new_sampler:
  585. if already_numbers:
  586. assert i > max(already_numbers)
  587. assert i not in already_numbers
  588. # 测试切换成多卡也没有问题
  589. other_rank_number = set()
  590. for rank in range(3):
  591. new_sampler = SequentialSampler(dataset=dataset)
  592. new_sampler.load_state_dict(states)
  593. new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
  594. new_sampler.set_epoch(0)
  595. count = 0
  596. seen = 0
  597. seen_in_other_rank = 0
  598. smaller = 0
  599. for i in new_sampler:
  600. if already_numbers:
  601. smaller += int(i <= max(already_numbers))
  602. seen_in_other_rank += int(i in other_rank_number)
  603. other_rank_number.add(i)
  604. seen += int(i in already_numbers)
  605. count += 1
  606. assert seen <= 1 if pad else seen == 0
  607. assert seen_in_other_rank<=rank # 因为pad可能重复
  608. assert smaller<=1 if pad else smaller==0
  609. @pytest.mark.parametrize('pad', [True, False])
  610. @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
  611. def test_state_dict_2(self, pad, num_consumed_samples):
  612. # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
  613. num_samples = 100
  614. dataset = DatasetWithVaryLength(num_of_data=num_samples)
  615. # 测试使用 前后shuffle不一致的load操作
  616. # lst = [30]
  617. already_numbers = set()
  618. sampler = SequentialSampler(dataset=dataset)
  619. sampler.set_distributed(num_replicas=2, rank=0)
  620. sampler.set_epoch(0)
  621. if num_consumed_samples>0:
  622. for i, j in enumerate(sampler, start=1):
  623. if already_numbers:
  624. assert j>max(already_numbers)
  625. already_numbers.add(j)
  626. if i == num_consumed_samples:
  627. break
  628. sampler = SequentialSampler(dataset=dataset)
  629. sampler.set_epoch(0)
  630. sampler.set_distributed(num_replicas=2, rank=1)
  631. if num_consumed_samples>0:
  632. for i, j in enumerate(sampler, start=1):
  633. already_numbers.add(j)
  634. if i == num_consumed_samples:
  635. break
  636. assert len(already_numbers) == num_consumed_samples*2
  637. states = sampler.state_dict()
  638. new_sampler = SequentialSampler(dataset=dataset)
  639. new_sampler.load_state_dict(states)
  640. new_sampler.set_epoch(0)
  641. for i in new_sampler:
  642. if already_numbers:
  643. assert i > max(already_numbers)
  644. assert i not in already_numbers
  645. # 测试切换成多卡也没有问题
  646. other_rank_number = set()
  647. for rank in range(3):
  648. new_sampler = SequentialSampler(dataset=dataset)
  649. new_sampler.load_state_dict(states)
  650. new_sampler.set_epoch(0)
  651. new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
  652. count = 0
  653. seen = 0
  654. seen_in_other_rank = 0
  655. smaller = 0
  656. for i in new_sampler:
  657. if already_numbers:
  658. smaller += int(i<max(already_numbers))
  659. seen_in_other_rank += int(i in other_rank_number)
  660. other_rank_number.add(i)
  661. seen += int(i in already_numbers)
  662. count += 1
  663. assert seen <= 1 if pad else seen == 0
  664. assert seen_in_other_rank<=1 # 因为pad可能重复
  665. assert smaller <= rank if pad else smaller == 0