| @@ -138,26 +138,12 @@ class JittorDriver(Driver): | |||
| num_consumed_batches = states.pop('num_consumed_batches') | |||
| if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
| sampler_states = sampler.state_dict() | |||
| # 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
| # 会造成多余实际消耗的问题。因为 | |||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
| if num_consumed_samples_array is not None: | |||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
| if dataloader_args.batch_size is not None: | |||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
| else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
| if dataloader_args.batch_size is not None: | |||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
| * num_consumed_batches | |||
| else: | |||
| if dataloader_args.batch_size is not None: | |||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
| * num_consumed_batches | |||
| else: | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| states['sampler_states'] = sampler_states | |||
| else: | |||
| @@ -118,14 +118,14 @@ class JittorSingleDriver(JittorDriver): | |||
| if args.sampler is None: | |||
| sampler = RandomSampler(args.dataset, args.shuffle) | |||
| return replace_sampler(dataloader, sampler) | |||
| elif isinstance(args.sampler, JittorRandomSampler): | |||
| elif type(args.sampler) is JittorRandomSampler: | |||
| if getattr(args.sampler, '_num_samples', None) is None \ | |||
| and getattr(args.sampler, 'rep', False) is False: | |||
| # 如果本来就是随机的,并且没有定制,直接替换掉吧。 | |||
| sampler = RandomSampler(args.sampler.dataset, shuffle=True) | |||
| logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | |||
| return replace_sampler(dataloader, sampler) | |||
| elif isinstance(args.sampler, JittorSequentialSampler): | |||
| elif type(args.sampler) is JittorSequentialSampler: | |||
| # 需要替换为不要 shuffle 的。 | |||
| sampler = RandomSampler(args.sampler.dataset, shuffle=False) | |||
| logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | |||
| @@ -73,6 +73,7 @@ from .utils import ( | |||
| _FleetWrappingModel, | |||
| replace_sampler, | |||
| replace_batch_sampler, | |||
| _check_dataloader_args_for_distributed | |||
| ) | |||
| from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object | |||
| @@ -453,6 +454,7 @@ class PaddleFleetDriver(PaddleDriver): | |||
| ) | |||
| return replace_sampler(dataloader, sampler) | |||
| else: | |||
| _check_dataloader_args_for_distributed(args, controller='Trainer') | |||
| sampler = RandomSampler( | |||
| dataset=args.dataset, | |||
| shuffle=args.shuffle, | |||
| @@ -222,26 +222,12 @@ class PaddleDriver(Driver): | |||
| num_consumed_batches = states.pop("num_consumed_batches") | |||
| if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | |||
| sampler_states = sampler.state_dict() | |||
| # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||
| # 会造成多余实际消耗的问题。 | |||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||
| if num_consumed_samples_array is not None: | |||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||
| if dataloader_args.batch_size is not None: | |||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
| else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||
| if dataloader_args.batch_size is not None: | |||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
| * num_consumed_batches | |||
| else: | |||
| if dataloader_args.batch_size is not None: | |||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
| * num_consumed_batches | |||
| else: | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| else: | |||
| raise RuntimeError( | |||
| "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | |||
| @@ -26,6 +26,11 @@ if _NEED_IMPORT_PADDLE: | |||
| import paddle | |||
| from paddle import DataParallel | |||
| from paddle.fluid.reader import _DatasetKind | |||
| from paddle.io import ( | |||
| RandomSampler as PaddleRandomSampler, | |||
| SequenceSampler as PaddleSequenialSampler, | |||
| BatchSampler as PaddleBatchSampler, | |||
| ) | |||
| __all__ = [ | |||
| "PaddleSingleDriver", | |||
| @@ -122,19 +127,21 @@ class PaddleSingleDriver(PaddleDriver): | |||
| return replace_sampler(dataloader, sampler) | |||
| if reproducible: | |||
| if isinstance(args.sampler, paddle.io.RandomSampler): | |||
| if getattr(args.sampler, '_num_samples', None) is None \ | |||
| and getattr(args.sampler, 'replacements', False) is False \ | |||
| and getattr(args.sampler, 'generator', None) is None: | |||
| # 如果本来就是随机的,并且没有定制,直接替换掉。 | |||
| sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||
| logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||
| if type(args.batch_sampler) is PaddleBatchSampler: | |||
| if type(args.sampler) is PaddleRandomSampler: | |||
| if isinstance(args.sampler, PaddleRandomSampler): | |||
| if getattr(args.sampler, '_num_samples', None) is None \ | |||
| and getattr(args.sampler, 'replacements', False) is False \ | |||
| and getattr(args.sampler, 'generator', None) is None: | |||
| # 如果本来就是随机的,并且没有定制,直接替换掉。 | |||
| sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||
| logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||
| return replace_sampler(dataloader, sampler) | |||
| elif type(args.sampler) is PaddleSequenialSampler: | |||
| # 需要替换为不要 shuffle 的。 | |||
| sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||
| logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||
| return replace_sampler(dataloader, sampler) | |||
| elif isinstance(args.sampler, paddle.io.SequenceSampler): | |||
| # 需要替换为不要 shuffle 的。 | |||
| sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||
| logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||
| return replace_sampler(dataloader, sampler) | |||
| batch_sampler = ReproduceBatchSampler( | |||
| batch_sampler=args.batch_sampler, | |||
| batch_size=args.batch_size, | |||
| @@ -23,7 +23,7 @@ if _NEED_IMPORT_PADDLE: | |||
| import paddle | |||
| from paddle import nn | |||
| from paddle.nn import Layer | |||
| from paddle.io import DataLoader, BatchSampler | |||
| from paddle.io import DataLoader, BatchSampler, RandomSampler, SequenceSampler | |||
| from paddle.amp import auto_cast, GradScaler | |||
| else: | |||
| from fastNLP.core.utils.dummy_class import DummyClass as Layer | |||
| @@ -249,3 +249,14 @@ def optimizer_state_to_device(state, device): | |||
| else: | |||
| new_state[name] = param | |||
| return new_state | |||
| def _check_dataloader_args_for_distributed(args, controller='Trainer'): | |||
| if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {RandomSampler, | |||
| SequenceSampler}): | |||
| mode = 'training' if controller == 'Trainer' else 'evaluation' | |||
| substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||
| raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | |||
| f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | |||
| f"``{substitution}``. The customized sampler should set for distributed running " | |||
| f"before initializing ``{controller}`` , and then set the " | |||
| f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") | |||
| @@ -11,11 +11,12 @@ from fastNLP.core.samplers import ( | |||
| ) | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | |||
| from tests.helpers.utils import magic_argv_env_context | |||
| from tests.helpers.utils import magic_argv_env_context, recover_logger | |||
| from fastNLP.envs.distributed import rank_zero_rm | |||
| from fastNLP import prepare_paddle_dataloader | |||
| from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather | |||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
| from fastNLP import logger | |||
| if _NEED_IMPORT_PADDLE: | |||
| import paddle | |||
| import paddle.distributed as dist | |||
| @@ -532,7 +533,6 @@ class TestSetDistReproDataloader: | |||
| num_samples = 200 | |||
| dataset = PaddleNormalXYDataset(num_samples) | |||
| dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) | |||
| model = PaddleNormalModel_Classification_1(10, 32) | |||
| self.driver.setup() | |||
| dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | |||
| @@ -581,8 +581,6 @@ class TestSetDistReproDataloader: | |||
| sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, | |||
| shuffle=shuffle, num_batch_per_bucket=2) | |||
| dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler) | |||
| model = PaddleNormalModel_Classification_1(10, 32) | |||
| device = [0, 1] | |||
| self.driver.setup() | |||
| dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | |||
| @@ -619,6 +617,95 @@ class TestSetDistReproDataloader: | |||
| finally: | |||
| dist.barrier() | |||
| @magic_argv_env_context | |||
| @recover_logger | |||
| @pytest.mark.parametrize("inherit", ([True, False])) | |||
| def test_customized_batch_sampler_dataloader(self, inherit): | |||
| try: | |||
| logger.set_stdout('raw', level='info') | |||
| # 需要检验一下 set_dist_repro_dataloader 是否可以在定制 batch_sampler 的情况下正确运行 | |||
| num_samples = 10 | |||
| dataset = PaddleNormalXYDataset(num_samples) | |||
| if inherit: | |||
| class BatchSampler(paddle.io.BatchSampler): | |||
| def __init__(self, dataset, batch_size): | |||
| self.dataset = dataset | |||
| self.batch_size = batch_size | |||
| def __iter__(self): | |||
| indices = list(range(len(dataset))) | |||
| for i in range(len(self)): | |||
| start = i * self.batch_size | |||
| end = (i + 1) * self.batch_size | |||
| return indices[start:end] | |||
| def __len__(self): | |||
| return (len(self.dataset)+self.batch_size-1)//self.batch_size | |||
| else: | |||
| class BatchSampler: | |||
| def __init__(self, dataset, batch_size): | |||
| self.dataset = dataset | |||
| self.batch_size = batch_size | |||
| def __iter__(self): | |||
| indices = list(range(len(dataset))) | |||
| for i in range(len(self)): | |||
| start = i * self.batch_size | |||
| end = (i + 1) * self.batch_size | |||
| return indices[start:end] | |||
| def __len__(self): | |||
| return (len(self.dataset)+self.batch_size-1)//self.batch_size | |||
| dl = prepare_paddle_dataloader(dataset, batch_sampler=BatchSampler(dataset, batch_size=4)) | |||
| self.driver.setup() | |||
| with pytest.raises(TypeError): | |||
| dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False) | |||
| finally: | |||
| pass | |||
| @magic_argv_env_context | |||
| @recover_logger | |||
| @pytest.mark.parametrize("inherit", ([True, False])) | |||
| def test_customized_sampler_dataloader(self, inherit): | |||
| try: | |||
| logger.set_stdout('raw', level='info') | |||
| # 需要检验一下 set_dist_repro_dataloader 是否可以在定制 batch_sampler 的情况下正确运行 | |||
| num_samples = 10 | |||
| dataset = PaddleNormalXYDataset(num_samples) | |||
| if inherit: | |||
| class Sampler(paddle.io.RandomSampler): | |||
| def __init__(self, dataset, batch_size): | |||
| self.dataset = dataset | |||
| self.batch_size = batch_size | |||
| def __iter__(self): | |||
| indices = list(range(len(dataset))) | |||
| return iter(indices) | |||
| def __len__(self): | |||
| return len(self.dataset) | |||
| else: | |||
| class Sampler: | |||
| def __init__(self, dataset, batch_size): | |||
| self.dataset = dataset | |||
| self.batch_size = batch_size | |||
| def __iter__(self): | |||
| indices = list(range(len(dataset))) | |||
| return iter(indices) | |||
| def __len__(self): | |||
| return len(self.dataset) | |||
| dl = prepare_paddle_dataloader(dataset, sampler=Sampler(dataset, batch_size=4)) | |||
| self.driver.setup() | |||
| # TODO 这里需要raise | |||
| with pytest.raises(TypeError): | |||
| dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False) | |||
| finally: | |||
| pass | |||
| ############################################################################ | |||
| # | |||