
# test_reproducible_sampler.py

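"""Tests for fastNLP's reproducible ``RandomSampler``.

Covers basic initialization, sharding across replicas via ``set_distributed``
(with and without padding), and checkpoint-resume behaviour through
``state_dict`` / ``load_state_dict``, including switches between
single-process and multi-replica setups and across different replica counts.
"""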
import unittest
from array import array
from functools import partial
from itertools import product

import numpy as np

from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset


class TestRandomSamplerYh(unittest.TestCase):
    def test_init(self):
        # Check that the sampler can be initialized and iterated.
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset)
        for i in sampler:
            pass

    def test_during_iter(self):
        # set_distributed must not be called while an iteration is in progress.
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset)
        for i in sampler:
            with self.assertRaises(AssertionError):
                sampler.set_distributed(1, 0)
            break

        # Once the iteration has run to completion, it should not raise.
        for i in sampler:
            pass
        sampler.set_distributed(1, 0)

    def test_set_distributed(self):
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset, shuffle=False)
        sampler.set_distributed(num_replicas=2, rank=0, pad=False)
        self.assertEqual(len(sampler), 50)
        count = 0
        for i in sampler:
            self.assertEqual(i % 2, 0)
            count += 1
        self.assertEqual(count, 50)

        sampler.set_distributed(num_replicas=2, rank=1, pad=False)
        self.assertEqual(len(sampler), 50)
        count = 0
        for i in sampler:
            self.assertEqual(i % 2, 1)
            count += 1
        self.assertEqual(count, 50)

        dataset = TorchNormalDataset(num_of_data=101)
        sampler = RandomSampler(dataset, shuffle=False)
        sampler.set_distributed(num_replicas=2, rank=0, pad=True)
        self.assertEqual(len(sampler), 51)
        count = 0
        for i in sampler:
            self.assertEqual(i % 2, 0)
            count += 1
        self.assertEqual(count, 51)

        sampler.set_distributed(num_replicas=2, rank=1, pad=True)
        self.assertEqual(len(sampler), 51)
        count = 0
        for i in sampler:
            if i != 0:
                self.assertEqual(i % 2, 1)
            count += 1
        self.assertEqual(count, 51)
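
    # A note on the assertions above, inferred from the test itself rather than
    # from documented behaviour: with shuffle=False the indices appear to be
    # dealt out round-robin, so rank r receives exactly the indices congruent
    # to r modulo num_replicas. With 101 samples and pad=True each rank is
    # padded up to 51 items, and the rank-1 loop skips the i == 0 case because
    # the padded element appears to reuse an index from outside its own shard.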

    def test_state_dict_check_length(self):
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset, shuffle=False)
        states = sampler.state_dict()

        # Loading into a sampler over a dataset of a different length must fail.
        new_ds = TorchNormalDataset(num_of_data=10)
        with self.assertRaises(AssertionError):
            new_sampler = RandomSampler(new_ds)
            new_sampler.load_state_dict(states)

        new_ds = TorchNormalDataset(num_of_data=100)
        new_sampler = RandomSampler(new_ds)
        new_sampler.load_state_dict(states)

    def test_state_dict(self):
        num_samples = 100
        dataset = TorchNormalDataset(num_of_data=num_samples)
        # Test loading a state dict when the shuffle setting differs before and
        # after the load.
        lst = [0] + np.random.randint(1, num_samples, size=3).tolist()
        for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], lst):
            with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle,
                              num_consumed_samples=num_consumed_samples):
                sampler = RandomSampler(dataset, shuffle=pre_shuffle)
                sampler.set_epoch(0)
                already_numbers = set()
                if num_consumed_samples > 0:
                    for i, j in enumerate(sampler, start=1):
                        already_numbers.add(j)
                        if i == num_consumed_samples:
                            break
                self.assertEqual(len(already_numbers), num_consumed_samples)

                states = sampler.state_dict()

                new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
                new_sampler.load_state_dict(states)
                new_sampler.set_epoch(0)
                # A resumed sampler must not replay indices that were already
                # consumed.
                for i in new_sampler:
                    self.assertNotIn(i, already_numbers)

                # Switching to a multi-replica setup should also work: the ranks
                # must partition the remaining indices without overlap.
                other_rank_number = set()
                for rank in range(3):
                    new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
                    new_sampler.load_state_dict(states)
                    new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False)
                    new_sampler.set_epoch(0)
                    count = 0
                    for i in new_sampler:
                        self.assertNotIn(i, other_rank_number)
                        other_rank_number.add(i)
                        self.assertNotIn(i, already_numbers)
                        count += 1

    def test_state_dict_2(self):
        # Test switching from multiple replicas back to a single process, or to
        # a different number of replicas.
        num_samples = 100
        dataset = TorchNormalDataset(num_of_data=num_samples)
        # Test loading a state dict when the shuffle setting differs before and
        # after the load.
        lst = [0] + np.random.randint(1, num_samples // 2, size=3).tolist()
        for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], lst):
            with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle,
                              num_consumed_samples=num_consumed_samples):
                already_numbers = set()
                sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
                sampler.set_distributed(num_replicas=2, rank=0)
                sampler.set_epoch(0)
                if num_consumed_samples > 0:
                    for i, j in enumerate(sampler, start=1):
                        already_numbers.add(j)
                        if i == num_consumed_samples:
                            break

                sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
                sampler.set_epoch(0)
                sampler.set_distributed(num_replicas=2, rank=1)
                if num_consumed_samples > 0:
                    for i, j in enumerate(sampler, start=1):
                        already_numbers.add(j)
                        if i == num_consumed_samples:
                            break
                self.assertEqual(len(already_numbers), num_consumed_samples * 2)

                states = sampler.state_dict()

                new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
                new_sampler.load_state_dict(states)
                new_sampler.set_epoch(0)
                for i in new_sampler:
                    self.assertNotIn(i, already_numbers)

                # Switching to a different number of replicas should also work.
                other_rank_number = set()
                for rank in range(3):
                    new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
                    new_sampler.load_state_dict(states)
                    new_sampler.set_epoch(0)
                    new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False)
                    count = 0
                    for i in new_sampler:
                        self.assertNotIn(i, other_rank_number)
                        other_rank_number.add(i)
                        self.assertNotIn(i, already_numbers)
                        count += 1
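

# The tests below lean on two properties that are assumptions about the
# implementation, inferred from the assertions rather than stated here:
# re-creating an iterator without touching the epoch replays the same
# permutation, and calling set_epoch() changes it, consistent with the
# permutation being derived from both the seed and the epoch.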
class TestRandomSampler(unittest.TestCase):
    # Single-process tests.
    def test_seed_work_when_shuffle_is_true(self):
        data_length = 100
        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        for shuffle in [True, False]:
            iterable = RandomSampler(dataset=torch_normal_data, shuffle=shuffle)
            # Consume some of the data without exhausting the iterator.
            iterable.set_epoch(1)
            iterator = iter(iterable)
            pre_data = []
            forward_steps = 30
            for _ in range(forward_steps):
                pre_data.append(next(iterator))

            # Re-creating the iterator should fully reset its state.
            iterator = iter(iterable)
            res = []
            for _ in range(forward_steps):
                res.append(next(iterator))
            assert pre_data == res

    # Test checkpoint resume: with shuffle enabled, the next epoch should differ
    # from the previous one, and after a resume both runs should agree on what
    # the next epoch looks like.
    def test_2(self):
        data_length = 100
        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        random_sampler_1 = RandomSampler(dataset=torch_normal_data, shuffle=True)
        iterator = iter(random_sampler_1)

        # First epoch.
        random_sampler_1.set_epoch(0)
        first_epoch = []
        forward_steps = 30
        for _ in range(forward_steps):
            first_epoch.append(next(iterator))

        # Save the checkpoint state before finishing the epoch.
        state = random_sampler_1.state_dict()

        # Record the remainder of the first epoch, to compare against the
        # resumed run later.
        first_left_data = []
        while True:
            try:
                first_left_data.append(next(iterator))
            except StopIteration:
                break

        # Second epoch.
        random_sampler_1.set_epoch(1)
        iterator = iter(random_sampler_1)
        second_epoch = []
        for _ in range(forward_steps):
            second_epoch.append(next(iterator))
        assert first_epoch != second_epoch

        # Reload the first-epoch state and check that the resume is correct.
        random_sampler_2 = RandomSampler(dataset=torch_normal_data, shuffle=True)
        random_sampler_2.load_state_dict(state)
        random_sampler_2.set_epoch(0)
        iterator = iter(random_sampler_2)
        re_first_epoch = []
        while True:
            try:
                re_first_epoch.append(next(iterator))
            except StopIteration:
                break
        assert re_first_epoch == first_left_data

        # The second epoch after the resume should match the original second
        # epoch exactly.
        random_sampler_2.set_epoch(1)
        iterator = iter(random_sampler_2)
        re_second_epoch = []
        for _ in range(forward_steps):
            re_second_epoch.append(next(iterator))
        assert re_second_epoch == second_epoch

    # Distributed tests.
    # If a sampler has not finished iterating and we call iter(sampler) again,
    # is the result correct? (It should behave as a brand-new iterator.)
    def test_3(self):
        data_length = 100
        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=False)
        random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
        iterable_items = [random_sampler_1, random_sampler_2]

        world_size = 3
        for pad in {True, False}:
            for iterable in iterable_items:
                for rank in range(world_size):
                    each_rank_iterable = iterable()
                    each_rank_iterable.set_epoch(0)
                    each_rank_iterable.set_distributed(num_replicas=world_size, rank=rank, pad=pad)
                    # Consume some of the data without exhausting the iterator.
                    iterator = iter(each_rank_iterable)
                    pre_data = []
                    forward_steps = 10
                    for _ in range(forward_steps):
                        pre_data.append(next(iterator))

                    # Re-creating the iterator should fully reset its state.
                    iterator = iter(each_rank_iterable)
                    res = []
                    for _ in range(forward_steps):
                        res.append(next(iterator))
                    assert res == pre_data

    # Test checkpoint resume: with shuffle enabled, the next epoch should differ
    # from the previous one, and after a resume both runs should agree on what
    # the next epoch looks like.
    def test_4(self):
        data_length = 100
        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
        world_size_1 = 2
        forward_steps = 10

        for pad in {True, False}:
            all_rank_state = {}
            all_rank_first_left_data = {}
            all_rank_second_epoch = {}
            for rank in range(world_size_1):
                each_rank_iterable = random_sampler_1()
                each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad)
                iterator = iter(each_rank_iterable)

                # First epoch.
                each_rank_iterable.set_epoch(0)
                first_epoch = []
                for _ in range(forward_steps):
                    first_epoch.append(next(iterator))

                # Save the checkpoint state before finishing the epoch.
                all_rank_state[rank] = each_rank_iterable.state_dict()

                # Record the remainder of the first epoch, to compare against
                # the resumed run later.
                first_left_data = []
                while True:
                    try:
                        first_left_data.append(next(iterator))
                    except StopIteration:
                        break
                all_rank_first_left_data[rank] = first_left_data

                # Second epoch.
                each_rank_iterable.set_epoch(1)
                iterator = iter(each_rank_iterable)
                second_epoch = []
                for _ in range(forward_steps):
                    second_epoch.append(next(iterator))
                all_rank_second_epoch[rank] = second_epoch
                assert first_epoch != second_epoch

            # Reload the first-epoch state and check that the resume is correct.
            random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
            for rank in range(world_size_1):
                each_rank_iterable = random_sampler_2()
                each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad)
                each_rank_iterable.load_state_dict(all_rank_state[rank])
                each_rank_iterable.set_epoch(0)
                iterator = iter(each_rank_iterable)
                re_first_epoch = []
                while True:
                    try:
                        re_first_epoch.append(next(iterator))
                    except StopIteration:
                        break
                assert re_first_epoch == all_rank_first_left_data[rank]

                # The second epoch after the resume should match the original
                # second epoch exactly.
                each_rank_iterable.set_epoch(1)
                iterator = iter(each_rank_iterable)
                re_second_epoch = []
                for _ in range(forward_steps):
                    re_second_epoch.append(next(iterator))
                assert re_second_epoch == all_rank_second_epoch[rank]

    # TODO: test checkpoint resume when the world_size changes under ddp.
    def test_5(self):
        ...