@@ -8,7 +8,6 @@ import math
from copy import deepcopy
from typing import Dict, Union, List
from itertools import chain
import os
import numpy as np
@@ -70,7 +70,7 @@ def model_and_optimizers():
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger(
def test_trainer_event_trigger_1(
        model_and_optimizers: TrainerParameters,
        driver,
        device,
@@ -104,5 +104,126 @@ def test_trainer_event_trigger(
        assert member.value in output[0]


@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6), ("torch", [6, 7])])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_2(
        model_and_optimizers: TrainerParameters,
        driver,
        device,
        n_epochs=2,
):
    @Trainer.on(Events.on_after_trainer_initialized)
    def on_after_trainer_initialized(trainer, driver):
        print("on_after_trainer_initialized")

    @Trainer.on(Events.on_sanity_check_begin)
    def on_sanity_check_begin(trainer):
        print("on_sanity_check_begin")

    @Trainer.on(Events.on_sanity_check_end)
    def on_sanity_check_end(trainer, sanity_check_res):
        print("on_sanity_check_end")

    @Trainer.on(Events.on_train_begin)
    def on_train_begin(trainer):
        print("on_train_begin")

    @Trainer.on(Events.on_train_end)
    def on_train_end(trainer):
        print("on_train_end")

    @Trainer.on(Events.on_train_epoch_begin)
    def on_train_epoch_begin(trainer):
        if trainer.cur_epoch_idx >= 1:
            # trigger on_exception;
            raise Exception
        print("on_train_epoch_begin")

    @Trainer.on(Events.on_train_epoch_end)
    def on_train_epoch_end(trainer):
        print("on_train_epoch_end")

    @Trainer.on(Events.on_fetch_data_begin)
    def on_fetch_data_begin(trainer):
        print("on_fetch_data_begin")

    @Trainer.on(Events.on_fetch_data_end)
    def on_fetch_data_end(trainer):
        print("on_fetch_data_end")

    @Trainer.on(Events.on_train_batch_begin)
    def on_train_batch_begin(trainer, batch, indices=None):
        print("on_train_batch_begin")

    @Trainer.on(Events.on_train_batch_end)
    def on_train_batch_end(trainer):
        print("on_train_batch_end")

    @Trainer.on(Events.on_exception)
    def on_exception(trainer, exception):
        print("on_exception")

    @Trainer.on(Events.on_before_backward)
    def on_before_backward(trainer, outputs):
        print("on_before_backward")

    @Trainer.on(Events.on_after_backward)
    def on_after_backward(trainer):
        print("on_after_backward")

    @Trainer.on(Events.on_before_optimizers_step)
    def on_before_optimizers_step(trainer, optimizers):
        print("on_before_optimizers_step")

    @Trainer.on(Events.on_after_optimizers_step)
    def on_after_optimizers_step(trainer, optimizers):
        print("on_after_optimizers_step")

    @Trainer.on(Events.on_before_zero_grad)
    def on_before_zero_grad(trainer, optimizers):
        print("on_before_zero_grad")

    @Trainer.on(Events.on_after_zero_grad)
    def on_after_zero_grad(trainer, optimizers):
        print("on_after_zero_grad")

    @Trainer.on(Events.on_evaluate_begin)
    def on_evaluate_begin(trainer):
        print("on_evaluate_begin")

    @Trainer.on(Events.on_evaluate_end)
    def on_evaluate_end(trainer, results):
        print("on_evaluate_end")

    with pytest.raises(Exception):
        with Capturing() as output:
            trainer = Trainer(
                model=model_and_optimizers.model,
                driver=driver,
                device=device,
                optimizers=model_and_optimizers.optimizers,
                train_dataloader=model_and_optimizers.train_dataloader,
                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
                input_mapping=model_and_optimizers.input_mapping,
                output_mapping=model_and_optimizers.output_mapping,
                metrics=model_and_optimizers.metrics,
                n_epochs=n_epochs,
            )
            trainer.run()

    if dist.is_initialized():
        dist.destroy_process_group()

    for name, member in Events.__members__.items():
        assert member.value in output[0]
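# A minimal sketch of the final assertion above, assuming fastNLP's `Events`
# behaves like a standard `enum.Enum` whose member values are the event-name
# strings (`_EventsSketch` is a hypothetical stand-in for Events):
from enum import Enum

class _EventsSketch(Enum):
    on_train_begin = "on_train_begin"
    on_train_end = "on_train_end"

captured = ["on_train_begin\non_train_end"]  # what Capturing() would collect
for _name, _member in _EventsSketch.__members__.items():
    assert _member.value in captured[0]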
@@ -1,7 +1,7 @@
from functools import reduce

from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader  # TODO: this class has changed; remember to update the test as well;
from tests.helpers.datasets.normal_data import NormalIterator
from tests.helpers.datasets.normal_data import NormalSampler


class Test_WrapDataLoader:

@@ -9,7 +9,7 @@ class Test_WrapDataLoader:
    def test_normal_generator(self):
        all_sanity_batches = [4, 20, 100]
        for sanity_batches in all_sanity_batches:
            data = NormalIterator(num_of_data=1000)
            data = NormalSampler(num_of_data=1000)
            wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
            dataloader = iter(wrapper)
            mark = 0
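# A minimal sketch of the truncation behavior this test exercises, assuming
# _TruncatedDataLoader simply stops after `num_batches` batches (the real
# fastNLP class may differ -- see the TODO above):
class _TruncatedLoaderSketch:
    def __init__(self, dataloader, num_batches):
        self.dataloader = dataloader
        self.num_batches = num_batches

    def __iter__(self):
        for i, batch in enumerate(self.dataloader):
            if i >= self.num_batches:
                break
            yield batch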
@@ -1,161 +1,131 @@
from array import array

import numpy as np
import pytest
from itertools import chain
from copy import deepcopy

from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset
class TestReproducibleBatchSampler:
    def test_1(self):
        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler or not makes no difference here;
        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)

        forward_steps = 3
        iterator = iter(reproduce_batch_sampler)
        for _ in range(forward_steps):
            next(iterator)

        # save the state;
        state = reproduce_batch_sampler.state_dict()
        assert state == {"index_list": array("I", list(range(100))),
                         "num_consumed_samples": forward_steps * 4,
                         "sampler_type": "ReproduceBatchSampler"}
        # create a new batch sampler, then load the state;
        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler or not makes no difference here;
        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
        reproduce_batch_sampler.load_state_dict(state)

        real_res = []
        supposed_res = (list(range(12, 16)), list(range(16, 20)))
        forward_steps = 2
        iter_dataloader = iter(reproduce_batch_sampler)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))

        for i in range(forward_steps):
            assert supposed_res[i] == real_res[i]

        # change the batch size;
        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler or not makes no difference here;
        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False)
        reproduce_batch_sampler.load_state_dict(state)

        real_res = []
        supposed_res = (list(range(12, 19)), list(range(19, 26)))
        forward_steps = 2
        iter_dataloader = iter(reproduce_batch_sampler)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))

        for i in range(forward_steps):
            assert supposed_res[i] == real_res[i]

        # check that the epoch after checkpoint resumption is a complete pass;
        # first finish the epoch in which the resumption happened;
        begin_idx = 26
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert data == list(range(begin_idx, begin_idx + _batch_size))
                begin_idx += _batch_size
            except StopIteration:
                break

        # start a new epoch;
        begin_idx = 0
        iter_dataloader = iter(reproduce_batch_sampler)
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert data == list(range(begin_idx, begin_idx + _batch_size))
                begin_idx += _batch_size
            except StopIteration:
                break

    def test_2(self):
        # check that each new epoch regenerates the index list instead of reusing the previous one;
        before_batch_size = 7
        sampler = NormalSampler(num_of_data=100)
        reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)

        # record one epoch's worth of data, to verify that the resumed run restores it correctly;
        all_supposed_data = []
        forward_steps = 3
        iter_dataloader = iter(reproduce_batch_sampler)
        for _ in range(forward_steps):
            all_supposed_data.extend(next(iter_dataloader))

        # 1. save the state
        state = reproduce_batch_sampler.state_dict()

        # 2. resume from the checkpoint with a newly created sampler, keeping batch_size unchanged;
        # shuffle is enabled here so we can check that the index list of the epoch *after*
        # the resumed one is regenerated rather than carried over;
        sampler = NormalSampler(num_of_data=100, shuffle=True)
        reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
        reproduce_batch_sampler.load_state_dict(state)

        # first run through the rest of the interrupted epoch;
        pre_index_list = reproduce_batch_sampler.state_dict()["index_list"]
        iter_dataloader = iter(reproduce_batch_sampler)
        while True:
            try:
                all_supposed_data.extend(next(iter_dataloader))
            except StopIteration:
                break

        assert all_supposed_data == list(pre_index_list)

        # start fresh epochs; each should differ from the saved order;
        for _ in range(3):
            iter_dataloader = iter(reproduce_batch_sampler)
            res = []
            while True:
                try:
                    res.extend(next(iter_dataloader))
                except StopIteration:
                    break
            assert res != all_supposed_data


class DatasetWithVaryLength:
@@ -0,0 +1,141 @@
from array import array

import torch
from torch.utils.data import DataLoader
import pytest

from fastNLP.core.samplers import ReproduceBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset


@pytest.mark.torch
class TestReproducibleBatchSamplerTorch:
    def test_torch_dataloader_1(self):
        # no shuffle
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            next(iter_dataloader)

        # 1. save the state
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
        state = _get_re_batchsampler.state_dict()
        assert state == {"index_list": array("I", list(range(100))),
                         "num_consumed_samples": forward_steps * before_batch_size,
                         "sampler_type": "ReproduceBatchSampler"}
        # 2. resume from the checkpoint with a newly created dataloader;
        # the batch_size is unchanged;
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        real_res = []
        supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))

        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])

        # change the batch_size;
        after_batch_size = 3
        dataloader = DataLoader(dataset, batch_size=after_batch_size)
        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        real_res = []
        supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))

        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])

        # check that the epoch after checkpoint resumption is a complete pass;
        # first finish the epoch in which the resumption happened;
        begin_idx = 27
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break

        # start a new epoch;
        begin_idx = 0
        iter_dataloader = iter(dataloader)
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break

    def test_torch_dataloader_2(self):
        # check that each new epoch regenerates the index list instead of reusing the previous one;
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        # enable shuffle to verify that the index list of the epoch after resumption is regenerated;
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        # record one epoch's worth of data, to verify that the resumed run restores it correctly;
        all_supposed_data = []
        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            all_supposed_data.extend(next(iter_dataloader).tolist())

        # 1. save the state
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
        state = _get_re_batchsampler.state_dict()

        # 2. resume from the checkpoint with a newly created dataloader;
        # the batch_size is unchanged;
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        iter_dataloader = iter(dataloader)

        # first run through the rest of the interrupted epoch;
        pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
        while True:
            try:
                all_supposed_data.extend(next(iter_dataloader).tolist())
            except StopIteration:
                break

        assert all_supposed_data == list(pre_index_list)

        # start fresh epochs; each should differ from the saved order;
        for _ in range(3):
            iter_dataloader = iter(dataloader)
            res = []
            while True:
                try:
                    res.extend(next(iter_dataloader).tolist())
                except StopIteration:
                    break
            assert res != all_supposed_data
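# For reference, a rough sketch of what `replace_batch_sampler` is assumed to
# do in these tests: rebuild the DataLoader around the new (reproducible)
# batch sampler so the iteration order comes from it; fastNLP's real helper
# presumably also carries over the remaining DataLoader settings.
def _replace_batch_sampler_sketch(dataloader: DataLoader, batch_sampler) -> DataLoader:
    return DataLoader(dataset=dataloader.dataset, batch_sampler=batch_sampler)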
@@ -1,13 +1,25 @@
import numpy as np
import random


class NormalIterator:
    def __init__(self, num_of_data=1000):
class NormalSampler:
    def __init__(self, num_of_data=1000, shuffle=False):
        self._num_of_data = num_of_data
        self._data = list(range(num_of_data))
        if shuffle:
            random.shuffle(self._data)
        self.shuffle = shuffle
        self._index = 0
        self.need_reinitialize = False

    def __iter__(self):
        if self.need_reinitialize:
            self._index = 0
            if self.shuffle:
                random.shuffle(self._data)
        else:
            self.need_reinitialize = True
        return self

    def __next__(self):
@@ -15,12 +27,45 @@ class NormalIterator:
            raise StopIteration
        _data = self._data[self._index]
        self._index += 1
        return self._data
        return _data

    def __len__(self):
        return self._num_of_data
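# Re-initialization contract (illustrative): the first iter() consumes the
# order built in __init__; every later iter() resets the cursor and, when
# shuffle=True, reshuffles the data in place.
#     >>> s = NormalSampler(num_of_data=3)
#     >>> list(s)
#     [0, 1, 2]
#     >>> list(s)
#     [0, 1, 2]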
class NormalBatchSampler:
    def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
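# Example (illustrative): batching a NormalSampler of 10 items by 4 with
# drop_last=False yields three batches, the last one short.
#     >>> bs = NormalBatchSampler(NormalSampler(num_of_data=10), 4, False)
#     >>> list(bs)
#     [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
#     >>> len(bs)
#     3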
class RandomDataset:
    def __init__(self, num_data=10):
        self.data = np.random.rand(num_data)

@@ -29,4 +74,7 @@ class RandomDataset:
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]