| @@ -123,17 +123,24 @@ class PaddleSingleDriver(PaddleDriver): | |||||
| if reproducible: | if reproducible: | ||||
| if isinstance(args.sampler, paddle.io.RandomSampler): | if isinstance(args.sampler, paddle.io.RandomSampler): | ||||
| # 如果本来就是随机的,直接替换 | |||||
| sampler = RandomSampler(args.sampler.data_source) | |||||
| logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
| if getattr(args.sampler, '_num_samples', None) is None \ | |||||
| and getattr(args.sampler, 'replacements', False) is False \ | |||||
| and getattr(args.sampler, 'generator', None) is None: | |||||
| # 如果本来就是随机的,并且没有定制,直接替换掉。 | |||||
| sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||||
| logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
| return replace_sampler(dataloader, sampler) | |||||
| elif isinstance(args.sampler, paddle.io.SequenceSampler): | |||||
| # 需要替换为不要 shuffle 的。 | |||||
| sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
| logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||||
| return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
| else: | |||||
| batch_sampler = ReproduceBatchSampler( | |||||
| batch_sampler=args.batch_sampler, | |||||
| batch_size=args.batch_size, | |||||
| drop_last=args.drop_last | |||||
| ) | |||||
| return replace_batch_sampler(dataloader, batch_sampler) | |||||
| batch_sampler = ReproduceBatchSampler( | |||||
| batch_sampler=args.batch_sampler, | |||||
| batch_size=args.batch_size, | |||||
| drop_last=args.drop_last | |||||
| ) | |||||
| return replace_batch_sampler(dataloader, batch_sampler) | |||||
| else: | else: | ||||
| return dataloader | return dataloader | ||||
| @@ -250,7 +250,7 @@ def test_trainer_output_from_new_proc( | |||||
| @pytest.mark.torch | @pytest.mark.torch | ||||
| @pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | |||||
| @pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) | |||||
| @pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 | @pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 | ||||
| @magic_argv_env_context | @magic_argv_env_context | ||||
| def test_trainer_on_exception( | def test_trainer_on_exception( | ||||
| @@ -386,22 +386,16 @@ class TestSetDistReproDataloader: | |||||
| def test_with_reproducible_true(self, shuffle): | def test_with_reproducible_true(self, shuffle): | ||||
| """ | """ | ||||
| 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | ||||
| 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True), | |||||
| 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler | |||||
| 当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler | |||||
| """ | """ | ||||
| dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | ||||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | ||||
| assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
| if shuffle: | |||||
| # 此时会替换 sampler | |||||
| assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler) | |||||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
| else: | |||||
| # 此时会替换 batch_sampler | |||||
| assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
| assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||||
| assert isinstance(replaced_loader.batch_sampler, paddle.io.BatchSampler) | |||||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
| assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
| assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
| assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
| @@ -400,22 +400,19 @@ class TestSetDistReproDataloader: | |||||
| def test_with_reproducible_true(self, shuffle): | def test_with_reproducible_true(self, shuffle): | ||||
| """ | """ | ||||
| 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 | ||||
| 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True), | |||||
| 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler | |||||
| 当dist为字符串时,此时应该返回新的 dataloader,会替换 sampler 为 RandomSampler; | |||||
| TODO: | |||||
| 在 Sampler 的参数不是默认的情况下会替换 batch_sampler | |||||
| """ | """ | ||||
| dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | ||||
| replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | ||||
| assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
| if shuffle: | |||||
| # 此时会替换 sampler | |||||
| assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler) | |||||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
| else: | |||||
| # 此时会替换 batch_sampler | |||||
| assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
| assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | |||||
| # 替换 sampler | |||||
| assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler) | |||||
| assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | |||||
| assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | |||||
| assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | |||||
| assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
| assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||