| @@ -23,7 +23,6 @@ from fastNLP.core.drivers import Driver | |||
| from fastNLP.core.drivers.utils import choose_driver | |||
| from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | |||
| from fastNLP.envs import rank_zero_call | |||
| from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||
| @@ -49,13 +49,13 @@ class Driver(ABC): | |||
| 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||
| 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||
| 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||
| 注意当 dist 为 ReproducibleIterator, RandomBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
| 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
| 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||
| :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||
| 可以可以加载。 | |||
| :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | |||
| 如果传入的 dataloader 中是 ReproducibleSampler 或者 RandomBatchSampler 需要重新初始化一个放入返回的 | |||
| 如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||
| dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | |||
| """ | |||
| if dist is None and reproducible is False: | |||
| @@ -3,7 +3,7 @@ from typing import Dict, Union | |||
| from .jittor_driver import JittorDriver | |||
| from fastNLP.core.utils import auto_param_call | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor | |||
| @@ -99,10 +99,10 @@ class JittorSingleDriver(JittorDriver): | |||
| def is_distributed(self): | |||
| return False | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler], | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||
| reproducible: bool = False, sampler_or_batch_sampler=None): | |||
| # reproducible 的相关功能暂时没有实现 | |||
| if isinstance(dist, RandomBatchSampler): | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| raise NotImplementedError | |||
| dataloader.batch_sampler = dist_sample | |||
| if isinstance(dist, ReproducibleSampler): | |||
| @@ -10,7 +10,7 @@ from fastNLP.core.utils import ( | |||
| get_paddle_device_id, | |||
| paddle_move_data_to_device, | |||
| ) | |||
| from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
| from fastNLP.core.log import logger | |||
| if _NEED_IMPORT_PADDLE: | |||
| @@ -139,12 +139,12 @@ class PaddleSingleDriver(PaddleDriver): | |||
| """ | |||
| return paddle_move_data_to_device(batch, "gpu:0") | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler], | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||
| reproducible: bool = False, sampler_or_batch_sampler=None): | |||
| # 暂时不支持IteratorDataset | |||
| assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
| "FastNLP does not support `IteratorDataset` now." | |||
| if isinstance(dist, RandomBatchSampler): | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| dataloader.batch_sampler = dist | |||
| return dataloader | |||
| if isinstance(dist, ReproducibleSampler): | |||
| @@ -154,11 +154,11 @@ class PaddleSingleDriver(PaddleDriver): | |||
| if reproducible: | |||
| if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
| return dataloader | |||
| elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||
| elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||
| return dataloader | |||
| else: | |||
| # TODO | |||
| batch_sampler = RandomBatchSampler( | |||
| batch_sampler = ReproducibleBatchSampler( | |||
| batch_sampler=dataloader.batch_sampler, | |||
| batch_size=dataloader.batch_sampler.batch_size, | |||
| drop_last=dataloader.drop_last | |||
| @@ -28,7 +28,7 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||
| ) | |||
| from fastNLP.core.drivers.utils import distributed_open_proc | |||
| from fastNLP.core.utils import auto_param_call, check_user_specific_params | |||
| from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \ | |||
| from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ | |||
| re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | |||
| from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | |||
| from fastNLP.core.log import logger | |||
| @@ -446,11 +446,11 @@ class TorchDDPDriver(TorchDriver): | |||
| # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | |||
| return self._test_step(batch) | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None, | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, | |||
| reproducible: bool = False): | |||
| # 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
| # 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||
| # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||
| if isinstance(dist, RandomBatchSampler): | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| dist.set_distributed( | |||
| num_replicas=self.world_size, | |||
| rank=self.global_rank, | |||
| @@ -472,7 +472,7 @@ class TorchDDPDriver(TorchDriver): | |||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | |||
| "control.") | |||
| else: | |||
| if isinstance(dist, RandomBatchSampler): | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| dist = re_instantiate_sampler(dist) | |||
| return replace_batch_sampler(dataloader, dist) | |||
| if isinstance(dist, ReproducibleSampler): | |||
| @@ -483,7 +483,7 @@ class TorchDDPDriver(TorchDriver): | |||
| elif dist == "dist": | |||
| args = self.get_dataloader_args(dataloader) | |||
| # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | |||
| if isinstance(args.batch_sampler, RandomBatchSampler): | |||
| if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
| batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
| batch_sampler.set_distributed( | |||
| num_replicas=self.world_size, | |||
| @@ -13,7 +13,7 @@ __all__ = [ | |||
| from .torch_driver import TorchDriver | |||
| from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | |||
| from fastNLP.core.utils import auto_param_call | |||
| from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||
| from fastNLP.core.log import logger | |||
| @@ -129,18 +129,18 @@ class TorchSingleDriver(TorchDriver): | |||
| else: | |||
| return self._test_step(batch) | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None, | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||
| reproducible: bool = False): | |||
| # 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
| if isinstance(dist, RandomBatchSampler): | |||
| # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| return replace_batch_sampler(dataloader, dist) | |||
| elif isinstance(dist, ReproducibleSampler): | |||
| return replace_sampler(dataloader, dist) | |||
| # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||
| args = self.get_dataloader_args(dataloader) | |||
| if isinstance(args.batch_sampler, RandomBatchSampler): | |||
| if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
| batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
| return replace_batch_sampler(dataloader, batch_sampler) | |||
| elif isinstance(args.sampler, ReproducibleSampler): | |||
| @@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver): | |||
| return replace_sampler(dataloader, sampler) | |||
| if reproducible: | |||
| batch_sampler = RandomBatchSampler( | |||
| batch_sampler = ReproducibleBatchSampler( | |||
| batch_sampler=args.batch_sampler, | |||
| batch_size=args.batch_size, | |||
| drop_last=args.drop_last | |||
| @@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||
| from fastNLP.envs import rank_zero_call | |||
| from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
| class TorchDriver(Driver): | |||
| @@ -183,9 +183,9 @@ class TorchDriver(Driver): | |||
| # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | |||
| # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 | |||
| # sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `RandomBatchSampler`; | |||
| # sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | |||
| dataloader_args = self.get_dataloader_args(dataloader) | |||
| if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): | |||
| if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
| sampler = dataloader_args.batch_sampler | |||
| elif dataloader_args.sampler: | |||
| sampler = dataloader_args.sampler | |||
| @@ -245,15 +245,14 @@ class TorchDriver(Driver): | |||
| # 3. 恢复 sampler 的状态; | |||
| dataloader_args = self.get_dataloader_args(dataloader) | |||
| if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): | |||
| if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
| sampler = dataloader_args.batch_sampler | |||
| elif isinstance(dataloader_args.sampler, ReproducibleIterator): | |||
| elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||
| sampler = dataloader_args.sampler | |||
| elif self.is_distributed(): | |||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " | |||
| "`RandomBatchSampler` or `ReproducibleIterator`.") | |||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||
| else: | |||
| sampler = RandomBatchSampler( | |||
| sampler = ReproducibleBatchSampler( | |||
| batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||
| batch_size=dataloader_args.batch_size, | |||
| drop_last=dataloader_args.drop_last | |||
| @@ -263,7 +262,7 @@ class TorchDriver(Driver): | |||
| # 4. 修改 trainer_state.batch_idx_in_epoch | |||
| # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||
| if not isinstance(sampler, RandomBatchSampler): | |||
| if not isinstance(sampler, ReproducibleBatchSampler): | |||
| if dataloader_args.drop_last: | |||
| batch_idx_in_epoch = len( | |||
| sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||
| @@ -19,6 +19,10 @@ __all__ = [ | |||
| "UnrepeatedSortedSampler", | |||
| "UnrepeatedSequentialSampler", | |||
| "RandomBatchSampler", | |||
| "BucketedBatchSampler", | |||
| "ReproducibleBatchSampler", | |||
| "re_instantiate_sampler", | |||
| "conversion_between_reproducible_and_unrepeated_sampler" | |||
| ] | |||
| @@ -28,5 +32,5 @@ from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, Unre | |||
| from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||
| from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | |||
| from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler | |||
| from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler | |||
| from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | |||
| @@ -17,6 +17,9 @@ from abc import abstractmethod | |||
| class ReproducibleBatchSampler: | |||
| def __init__(self, **kwargs): | |||
| pass | |||
| @abstractmethod | |||
| def set_distributed(self, num_replicas, rank, pad=True): | |||
| raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") | |||
| @@ -41,6 +44,10 @@ class ReproducibleBatchSampler: | |||
| def set_epoch(self, epoch): | |||
| pass | |||
| @property | |||
| def batch_idx_in_epoch(self): | |||
| raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | |||
| class RandomBatchSampler(ReproducibleBatchSampler): | |||
| # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | |||
| @@ -54,6 +61,8 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | |||
| :param kwargs: fastNLP 内部使用。 | |||
| """ | |||
| super().__init__() | |||
| self.batch_sampler = batch_sampler | |||
| self.batch_size = batch_size | |||
| self.drop_last = drop_last | |||