| @@ -138,26 +138,12 @@ class JittorDriver(Driver): | |||||
| num_consumed_batches = states.pop('num_consumed_batches') | num_consumed_batches = states.pop('num_consumed_batches') | ||||
| if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | ||||
| sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
| # 需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
| # 会造成多余实际消耗的问题。因为 | |||||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
| if num_consumed_samples_array is not None: | |||||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
| if dataloader_args.batch_size is not None: | |||||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
| else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
| "it may cause missing some samples when reload.") | |||||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
| if dataloader_args.batch_size is not None: | |||||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
| * num_consumed_batches | |||||
| else: | else: | ||||
| if dataloader_args.batch_size is not None: | |||||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
| * num_consumed_batches | |||||
| else: | |||||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
| "it may cause missing some samples when reload.") | |||||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
| "it may cause missing some samples when reload.") | |||||
| states['sampler_states'] = sampler_states | states['sampler_states'] = sampler_states | ||||
| else: | else: | ||||
| @@ -118,14 +118,14 @@ class JittorSingleDriver(JittorDriver): | |||||
| if args.sampler is None: | if args.sampler is None: | ||||
| sampler = RandomSampler(args.dataset, args.shuffle) | sampler = RandomSampler(args.dataset, args.shuffle) | ||||
| return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
| elif isinstance(args.sampler, JittorRandomSampler): | |||||
| elif type(args.sampler) is JittorRandomSampler: | |||||
| if getattr(args.sampler, '_num_samples', None) is None \ | if getattr(args.sampler, '_num_samples', None) is None \ | ||||
| and getattr(args.sampler, 'rep', False) is False: | and getattr(args.sampler, 'rep', False) is False: | ||||
| # 如果本来就是随机的,并且没有定制,直接替换掉吧。 | # 如果本来就是随机的,并且没有定制,直接替换掉吧。 | ||||
| sampler = RandomSampler(args.sampler.dataset, shuffle=True) | sampler = RandomSampler(args.sampler.dataset, shuffle=True) | ||||
| logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | logger.debug("Replace jittor RandomSampler into fastNLP RandomSampler.") | ||||
| return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
| elif isinstance(args.sampler, JittorSequentialSampler): | |||||
| elif type(args.sampler) is JittorSequentialSampler: | |||||
| # 需要替换为不要 shuffle 的。 | # 需要替换为不要 shuffle 的。 | ||||
| sampler = RandomSampler(args.sampler.dataset, shuffle=False) | sampler = RandomSampler(args.sampler.dataset, shuffle=False) | ||||
| logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | logger.debug("Replace jittor SequentialSampler into fastNLP RandomSampler.") | ||||
| @@ -73,6 +73,7 @@ from .utils import ( | |||||
| _FleetWrappingModel, | _FleetWrappingModel, | ||||
| replace_sampler, | replace_sampler, | ||||
| replace_batch_sampler, | replace_batch_sampler, | ||||
| _check_dataloader_args_for_distributed | |||||
| ) | ) | ||||
| from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object | from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object | ||||
| @@ -453,6 +454,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| ) | ) | ||||
| return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
| else: | else: | ||||
| _check_dataloader_args_for_distributed(args, controller='Trainer') | |||||
| sampler = RandomSampler( | sampler = RandomSampler( | ||||
| dataset=args.dataset, | dataset=args.dataset, | ||||
| shuffle=args.shuffle, | shuffle=args.shuffle, | ||||
| @@ -222,26 +222,12 @@ class PaddleDriver(Driver): | |||||
| num_consumed_batches = states.pop("num_consumed_batches") | num_consumed_batches = states.pop("num_consumed_batches") | ||||
| if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | if hasattr(sampler, "state_dict") and callable(sampler.state_dict): | ||||
| sampler_states = sampler.state_dict() | sampler_states = sampler.state_dict() | ||||
| # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
| # 会造成多余实际消耗的问题。 | |||||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
| if num_consumed_samples_array is not None: | |||||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
| if dataloader_args.batch_size is not None: | |||||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
| else: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
| "it may cause missing some samples when reload.") | |||||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
| if dataloader_args.batch_size is not None: | |||||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
| * num_consumed_batches | |||||
| else: | else: | ||||
| if dataloader_args.batch_size is not None: | |||||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||||
| * num_consumed_batches | |||||
| else: | |||||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
| "it may cause missing some samples when reload.") | |||||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||||
| "it may cause missing some samples when reload.") | |||||
| else: | else: | ||||
| raise RuntimeError( | raise RuntimeError( | ||||
| "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | "The sampler has no `state_dict()` method, it will fail to recover to the specific batch.") | ||||
| @@ -26,6 +26,11 @@ if _NEED_IMPORT_PADDLE: | |||||
| import paddle | import paddle | ||||
| from paddle import DataParallel | from paddle import DataParallel | ||||
| from paddle.fluid.reader import _DatasetKind | from paddle.fluid.reader import _DatasetKind | ||||
| from paddle.io import ( | |||||
| RandomSampler as PaddleRandomSampler, | |||||
| SequenceSampler as PaddleSequenialSampler, | |||||
| BatchSampler as PaddleBatchSampler, | |||||
| ) | |||||
| __all__ = [ | __all__ = [ | ||||
| "PaddleSingleDriver", | "PaddleSingleDriver", | ||||
| @@ -122,19 +127,21 @@ class PaddleSingleDriver(PaddleDriver): | |||||
| return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
| if reproducible: | if reproducible: | ||||
| if isinstance(args.sampler, paddle.io.RandomSampler): | |||||
| if getattr(args.sampler, '_num_samples', None) is None \ | |||||
| and getattr(args.sampler, 'replacements', False) is False \ | |||||
| and getattr(args.sampler, 'generator', None) is None: | |||||
| # 如果本来就是随机的,并且没有定制,直接替换掉。 | |||||
| sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||||
| logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
| if type(args.batch_sampler) is PaddleBatchSampler: | |||||
| if type(args.sampler) is PaddleRandomSampler: | |||||
| if isinstance(args.sampler, PaddleRandomSampler): | |||||
| if getattr(args.sampler, '_num_samples', None) is None \ | |||||
| and getattr(args.sampler, 'replacements', False) is False \ | |||||
| and getattr(args.sampler, 'generator', None) is None: | |||||
| # 如果本来就是随机的,并且没有定制,直接替换掉。 | |||||
| sampler = RandomSampler(args.sampler.data_source, shuffle=True) | |||||
| logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | |||||
| return replace_sampler(dataloader, sampler) | |||||
| elif type(args.sampler) is PaddleSequenialSampler: | |||||
| # 需要替换为不要 shuffle 的。 | |||||
| sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
| logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||||
| return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
| elif isinstance(args.sampler, paddle.io.SequenceSampler): | |||||
| # 需要替换为不要 shuffle 的。 | |||||
| sampler = RandomSampler(args.sampler.data_source, shuffle=False) | |||||
| logger.debug("Replace paddle SequentialSampler into fastNLP RandomSampler.") | |||||
| return replace_sampler(dataloader, sampler) | |||||
| batch_sampler = ReproduceBatchSampler( | batch_sampler = ReproduceBatchSampler( | ||||
| batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
| batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
| @@ -23,7 +23,7 @@ if _NEED_IMPORT_PADDLE: | |||||
| import paddle | import paddle | ||||
| from paddle import nn | from paddle import nn | ||||
| from paddle.nn import Layer | from paddle.nn import Layer | ||||
| from paddle.io import DataLoader, BatchSampler | |||||
| from paddle.io import DataLoader, BatchSampler, RandomSampler, SequenceSampler | |||||
| from paddle.amp import auto_cast, GradScaler | from paddle.amp import auto_cast, GradScaler | ||||
| else: | else: | ||||
| from fastNLP.core.utils.dummy_class import DummyClass as Layer | from fastNLP.core.utils.dummy_class import DummyClass as Layer | ||||
| @@ -249,3 +249,14 @@ def optimizer_state_to_device(state, device): | |||||
| else: | else: | ||||
| new_state[name] = param | new_state[name] = param | ||||
| return new_state | return new_state | ||||
def _check_dataloader_args_for_distributed(args, controller='Trainer'):
    """Reject dataloaders whose ``batch_sampler`` or ``sampler`` is customized.

    The check passes only when ``args.batch_sampler`` is exactly a
    ``BatchSampler`` and ``args.sampler`` is exactly a ``RandomSampler`` or a
    ``SequenceSampler``.  Types are compared by identity, so subclasses of
    these classes are treated as customized and rejected as well.

    :param args: object exposing ``batch_sampler`` and ``sampler`` attributes.
    :param controller: name of the calling controller.  ``'Trainer'`` selects
        the training wording of the error message; any other value selects
        the evaluation wording.
    :raises TypeError: if either the batch sampler or the sampler is not one
        of the stock types listed above.
    """
    # ``args.sampler`` is inspected only after the batch sampler has been
    # accepted, so a customized batch sampler is rejected first.
    if type(args.batch_sampler) is BatchSampler \
            and type(args.sampler) in {RandomSampler, SequenceSampler}:
        return

    if controller == 'Trainer':
        mode, substitution = 'training', 'fastNLP.RandomSampler'
    else:
        mode, substitution = 'evaluation', 'fastNLP.UnrepeatedSequentialSampler'

    message = (
        f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
        f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
        f"``{substitution}``. The customized sampler should set for distributed running "
        f"before initializing ``{controller}`` , and then set the "
        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
    )
    raise TypeError(message)
| @@ -11,11 +11,12 @@ from fastNLP.core.samplers import ( | |||||
| ) | ) | ||||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
| from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleNormalXYDataset | ||||
| from tests.helpers.utils import magic_argv_env_context | |||||
| from tests.helpers.utils import magic_argv_env_context, recover_logger | |||||
| from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
| from fastNLP import prepare_paddle_dataloader | from fastNLP import prepare_paddle_dataloader | ||||
| from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather | from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
| from fastNLP import logger | |||||
| if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
| import paddle | import paddle | ||||
| import paddle.distributed as dist | import paddle.distributed as dist | ||||
| @@ -532,7 +533,6 @@ class TestSetDistReproDataloader: | |||||
| num_samples = 200 | num_samples = 200 | ||||
| dataset = PaddleNormalXYDataset(num_samples) | dataset = PaddleNormalXYDataset(num_samples) | ||||
| dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) | dl = prepare_paddle_dataloader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=drop_last) | ||||
| model = PaddleNormalModel_Classification_1(10, 32) | |||||
| self.driver.setup() | self.driver.setup() | ||||
| dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | ||||
| @@ -581,8 +581,6 @@ class TestSetDistReproDataloader: | |||||
| sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, | sampler = BucketedBatchSampler(dataset, length=dataset._data, batch_size=batch_size, drop_last=drop_last, | ||||
| shuffle=shuffle, num_batch_per_bucket=2) | shuffle=shuffle, num_batch_per_bucket=2) | ||||
| dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler) | dl = prepare_paddle_dataloader(dataset, batch_sampler=sampler) | ||||
| model = PaddleNormalModel_Classification_1(10, 32) | |||||
| device = [0, 1] | |||||
| self.driver.setup() | self.driver.setup() | ||||
| dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=reproducible) | ||||
| @@ -619,6 +617,95 @@ class TestSetDistReproDataloader: | |||||
| finally: | finally: | ||||
| dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context
@recover_logger
@pytest.mark.parametrize("inherit", ([True, False]))
def test_customized_batch_sampler_dataloader(self, inherit):
    """``set_dist_repro_dataloader`` must raise ``TypeError`` for a custom batch sampler.

    :param inherit: when ``True`` the custom batch sampler subclasses
        ``paddle.io.BatchSampler``; otherwise it is a plain class that only
        provides ``__iter__`` and ``__len__``.
    """
    logger.set_stdout('raw', level='info')
    num_samples = 10
    dataset = PaddleNormalXYDataset(num_samples)

    # Both variants share one implementation; only the base class differs.
    base_class = paddle.io.BatchSampler if inherit else object

    class CustomBatchSampler(base_class):
        def __init__(self, dataset, batch_size):
            # The parent initializer is not invoked; only these two
            # attributes are populated.
            self.dataset = dataset
            self.batch_size = batch_size

        def __iter__(self):
            indices = list(range(len(self.dataset)))
            for i in range(len(self)):
                start = i * self.batch_size
                end = (i + 1) * self.batch_size
                # ``yield`` makes ``__iter__`` a generator.  Returning the
                # slice here would hand back a list, and ``iter()`` on the
                # sampler would then raise ``TypeError`` on its own -- the
                # same exception type this test expects from the driver.
                yield indices[start:end]

        def __len__(self):
            return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    dl = prepare_paddle_dataloader(dataset, batch_sampler=CustomBatchSampler(dataset, batch_size=4))
    self.driver.setup()
    with pytest.raises(TypeError):
        self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
@magic_argv_env_context
@recover_logger
@pytest.mark.parametrize("inherit", ([True, False]))
def test_customized_sampler_dataloader(self, inherit):
    """``set_dist_repro_dataloader`` must raise ``TypeError`` for a custom sampler.

    :param inherit: when ``True`` the custom sampler subclasses
        ``paddle.io.RandomSampler``; otherwise it is a plain class that only
        provides ``__iter__`` and ``__len__``.
    """
    logger.set_stdout('raw', level='info')
    num_samples = 10
    dataset = PaddleNormalXYDataset(num_samples)

    # Both variants share one implementation; only the base class differs.
    base_class = paddle.io.RandomSampler if inherit else object

    class CustomSampler(base_class):
        def __init__(self, dataset, batch_size):
            # The parent initializer is not invoked; only these two
            # attributes are populated.
            self.dataset = dataset
            self.batch_size = batch_size

        def __iter__(self):
            # Sequential indices over the whole dataset.
            return iter(list(range(len(dataset))))

        def __len__(self):
            return len(self.dataset)

    dl = prepare_paddle_dataloader(dataset, sampler=CustomSampler(dataset, batch_size=4))
    self.driver.setup()
    with pytest.raises(TypeError):
        self.driver.set_dist_repro_dataloader(dataloader=dl, dist='dist', reproducible=False)
| ############################################################################ | ############################################################################ | ||||
| # | # | ||||