from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
class TestReproducibleBatchSampler:
    """Checkpoint/resume behaviour of ``RandomBatchSampler`` wrapped around a torch ``DataLoader``."""
    # TODO: split into smaller tests so each one verifies a single behaviour.

    def test_torch_dataloader_1(self):
        import torch
        from torch.utils.data import DataLoader

        # Plain dataloader, no shuffling.
        old_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        dataloader = DataLoader(dataset, batch_size=old_batch_size)
        resumable = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, resumable)

        # Consume a few batches, then snapshot the sampler state.
        consumed_steps = 3
        batch_iter = iter(dataloader)
        for _ in range(consumed_steps):
            next(batch_iter)

        # 1. save the state
        saved_sampler = dataloader.batch_sampler
        assert isinstance(saved_sampler, RandomBatchSampler)
        state = saved_sampler.state_dict()
        assert state == {
            "index_list": array("I", list(range(100))),
            "num_consumed_samples": consumed_steps * old_batch_size,
            "sampler_type": "RandomBatchSampler",
        }

        # 2. resume from the checkpoint with a freshly built dataloader,
        # keeping the same batch size.
        dataloader = DataLoader(dataset, batch_size=old_batch_size)
        resumable = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        resumable.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, resumable)

        produced = []
        expected = (torch.arange(21, 28), torch.arange(28, 35))
        batch_iter = iter(dataloader)
        for _ in range(2):
            produced.append(next(batch_iter))
        for got, want in zip(produced, expected):
            assert all(got == want)

        # Resume again, this time with a smaller batch size.
        new_batch_size = 3
        dataloader = DataLoader(dataset, batch_size=new_batch_size)
        resumable = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        resumable.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, resumable)

        produced = []
        expected = (torch.arange(21, 24), torch.arange(24, 27))
        batch_iter = iter(dataloader)
        for _ in range(2):
            produced.append(next(batch_iter))
        for got, want in zip(produced, expected):
            assert all(got == want)

        # The remainder of the resumed epoch must cover indices 27..99 in order.
        next_idx = 27
        for batch in batch_iter:
            width = len(batch)
            assert all(batch == torch.arange(next_idx, next_idx + width))
            next_idx += width

        # A brand-new epoch after resumption must start from 0 and be complete.
        next_idx = 0
        for batch in iter(dataloader):
            width = len(batch)
            assert all(batch == torch.arange(next_idx, next_idx + width))
            next_idx += width

    def test_torch_dataloader_2(self):
        # A new epoch must regenerate its index list instead of reusing the old one.
        from torch.utils.data import DataLoader

        batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        # Shuffle is enabled so a reshuffle on the post-resume epoch is observable.
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        resumable = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, resumable)

        # Record everything yielded in this epoch so the restored order can be checked.
        seen = []
        batch_iter = iter(dataloader)
        for _ in range(3):
            seen.extend(next(batch_iter).tolist())

        # 1. save the state
        saved_sampler = dataloader.batch_sampler
        assert isinstance(saved_sampler, RandomBatchSampler)
        state = saved_sampler.state_dict()

        # 2. resume with a freshly built dataloader (batch size unchanged).
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        resumable = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        resumable.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, resumable)

        # Drain the rest of the interrupted epoch (still via the old iterator)
        # and compare against the restored index list.
        pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
        for batch in batch_iter:
            seen.extend(batch.tolist())
        assert seen == list(pre_index_list)

        # Fresh epochs after resumption must each iterate to completion.
        for _ in range(3):
            collected = []
            for batch in iter(dataloader):
                collected.append(batch)

    def test_3(self):
        import torch
        from torch.utils.data import DataLoader

        batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        # Abandon an iterator part-way through an epoch...
        for idx, data in enumerate(dataloader):
            if idx > 3:
                break

        # ...then a brand-new iterator must still run a full epoch cleanly.
        iterator = iter(dataloader)
        for each in iterator:
            pass
# NOTE(review): a commented-out, byte-for-byte duplicate of the
# TestReproducibleBatchSampler class above used to live here; it carried no
# information beyond the live copy and was removed (recoverable from VCS history).
| class DatasetWithVaryLength: | |||