| @@ -19,7 +19,7 @@ from fastNLP.core.utils import ( | |||
| paddle_move_data_to_device, | |||
| is_in_paddle_dist, | |||
| ) | |||
| from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler | |||
| from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler | |||
| from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES | |||
| from fastNLP.core.log import logger | |||
| @@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver): | |||
| return dataloader | |||
| # evaluator | |||
| elif dist == "unrepeatdist": | |||
| sampler = UnrepeatedDistributedSampler( | |||
| sampler = UnrepeatedSampler( | |||
| dataset=dataloader.dataset, | |||
| shuffle=shuffle, | |||
| seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||
| @@ -23,11 +23,12 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||
| ForwardState, | |||
| _MODE_PARAMETER, | |||
| reset_seed, | |||
| replace_sampler | |||
| replace_sampler, | |||
| replace_batch_sampler | |||
| ) | |||
| from fastNLP.core.drivers.utils import distributed_open_proc | |||
| from fastNLP.core.utils import auto_param_call, check_user_specific_params | |||
| from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler | |||
| from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler | |||
| from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | |||
| @@ -445,11 +446,25 @@ class TorchDDPDriver(TorchDriver): | |||
| # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | |||
| return self._test_step(batch) | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], | |||
| reproducible: bool = False, sampler_or_batch_sampler=None): | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None, | |||
| reproducible: bool = False): | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| dist = re_instantiate_sampler(dist) | |||
| dist.set_distributed( | |||
| num_replicas=self.world_size, | |||
| rank=self.global_rank, | |||
| pad=True | |||
| ) | |||
| return replace_batch_sampler(dataloader, dist) | |||
| if isinstance(dist, ReproducibleIterator): | |||
| # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||
| dist = re_instantiate_sampler(dist) | |||
| dist.set_distributed( | |||
| num_replicas=self.world_size, | |||
| rank=self.global_rank, | |||
| pad=True | |||
| ) | |||
| return replace_sampler(dataloader, dist) | |||
| # trainer, evaluator | |||
| @@ -463,7 +478,15 @@ class TorchDDPDriver(TorchDriver): | |||
| elif dist == "dist": | |||
| args = self.get_dataloader_args(dataloader) | |||
| # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | |||
| if isinstance(args.sampler, ReproducibleIterator): | |||
| if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
| batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
| batch_sampler.set_distributed( | |||
| num_replicas=self.world_size, | |||
| rank=self.global_rank, | |||
| pad=True | |||
| ) | |||
| return replace_batch_sampler(dataloader, batch_sampler) | |||
| elif isinstance(args.sampler, ReproducibleIterator): | |||
| sampler = re_instantiate_sampler(args.sampler) | |||
| sampler.set_distributed( | |||
| num_replicas=self.world_size, | |||
| @@ -477,7 +500,6 @@ class TorchDDPDriver(TorchDriver): | |||
| shuffle=args.shuffle, | |||
| seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) | |||
| ) | |||
| # todo 这个你写个todo吧,有两个角度;第一个是dataloader即使检测到sampler是我们reproducible,也不能直接set_distributeds; 第二个如果是单卡的,也需要替换sampler乃至切换sampler的状态,方式之前多卡,现在切换成单卡运行 | |||
| sampler.set_distributed( | |||
| num_replicas=self.world_size, | |||
| rank=self.global_rank, | |||
| @@ -487,8 +509,11 @@ class TorchDDPDriver(TorchDriver): | |||
| # evaluator | |||
| elif dist == "unrepeatdist": | |||
| # todo @yh,补充 unrepeatdist 相关内容; | |||
| args = self.get_dataloader_args(dataloader) | |||
| sampler = UnrepeatedDistributedSampler( | |||
| # todo 判断 batch_sampler; | |||
| sampler = UnrepeatedSampler( | |||
| dataset=args.dataset, | |||
| shuffle=args.shuffle, | |||
| ) | |||
| @@ -133,8 +133,10 @@ class TorchSingleDriver(TorchDriver): | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, | |||
| reproducible: bool = False): | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| dist = re_instantiate_sampler(dist) | |||
| return replace_batch_sampler(dataloader, dist) | |||
| elif isinstance(dist, ReproducibleIterator): | |||
| dist = re_instantiate_sampler(dist) | |||
| return replace_sampler(dataloader, dist) | |||
| if reproducible: | |||
| @@ -244,8 +244,34 @@ class TorchDriver(Driver): | |||
| logger.debug("Load model.") | |||
| # 3. 恢复 sampler 的状态; | |||
| """ | |||
| 使用场景: | |||
| 现在sampler/batch_sampler的替换情况: | |||
| 1. 单卡多卡; | |||
| 2. 是否断点重训; | |||
| 3. 用户通过 dist 传入; | |||
| 4. 用户自己直接在外面替换dataloader的sampler或者 batchsampler; | |||
| 应当确定的规则: | |||
| batchsampler 优先级高于 sampler; | |||
| 单卡: | |||
| 不是断点重训: | |||
| 用户自己 | |||
| 用户不自己在外面直接替换 sampler 或者 batchsampler | |||
| 1. 单卡: | |||
| """ | |||
| dataloader_args = self.get_dataloader_args(dataloader) | |||
| # todo 先捋一下; | |||
| # batch_sampler = dataloader_args.batch_sampler | |||
| # if not (hasattr(batch_sampler, 'load_state_dict') and callable(batch_sampler.load_state_dict)): | |||
| sampler = dataloader_args.sampler | |||
| if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): | |||
| # 说明这里需要使用 ReproduceSampler 来弄一下了 | |||
| @@ -3,19 +3,24 @@ __all__ = [ | |||
| 'SortedSampler', | |||
| 'ConstTokenNumSampler', | |||
| 'ConstantTokenNumSampler', | |||
| 'UnrepeatedDistributedSampler', | |||
| 'MixSampler', | |||
| 'InnerSampler', | |||
| 'DopedSampler', | |||
| 'MixSequentialSampler', | |||
| 'PollingSampler', | |||
| 'ReproducibleIterator', | |||
| 'RandomSampler', | |||
| 're_instantiate_sampler' | |||
| 're_instantiate_sampler', | |||
| 'UnrepeatedSampler', | |||
| "UnrepeatedSortedSampler" | |||
| ] | |||
| from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler | |||
| from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||
| from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | |||
| from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler | |||
| from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||
| from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler | |||
| from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler | |||
| @@ -4,7 +4,6 @@ from typing import Union, List, Iterable, Dict | |||
| __all__ = [ | |||
| 'MixSampler', | |||
| 'InnerSampler', | |||
| 'DopedSampler', | |||
| 'MixSequentialSampler', | |||
| 'PollingSampler' | |||
| @@ -16,7 +16,6 @@ def re_instantiate_sampler(sampler): | |||
| return type(sampler)(**all_attributes) | |||
| class ReproducibleIterator: | |||
| """ | |||
| 注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||
| @@ -7,7 +7,6 @@ __all__ = [ | |||
| "SortedSampler", | |||
| 'ConstTokenNumSampler', | |||
| "ConstantTokenNumSampler", | |||
| "UnrepeatedDistributedSampler", | |||
| ] | |||
| from itertools import chain | |||
| @@ -18,7 +17,7 @@ import numpy as np | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| from torch.utils.data import SequentialSampler, Sampler, RandomSampler | |||
| from torch.utils.data import Sampler | |||
| else: | |||
| from fastNLP.core.utils.dummy_class import DummyClass as Sampler | |||
| @@ -727,87 +726,3 @@ def k_means_bucketing(lengths, buckets): | |||
| if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]: | |||
| bucket_data[bucket_id].append(idx) | |||
| return bucket_data | |||
| class UnrepeatedDistributedSampler: | |||
| def __init__(self, dataset, shuffle: bool = False, seed: int = 0): | |||
| """ | |||
| 考虑在多卡evaluate的场景下,不能重复sample。 | |||
| :param dataset: | |||
| :param shuffle: | |||
| :param seed: | |||
| """ | |||
| self.dataset = dataset | |||
| self.shuffle = shuffle | |||
| self.seed = seed | |||
| # 多卡的相关的参数 | |||
| self.num_replicas = 1 | |||
| self.rank = 0 | |||
| self.epoch = -1 | |||
| def __len__(self): | |||
| """ | |||
| 返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | |||
| :return: | |||
| """ | |||
| num_common = len(self.dataset)//self.num_replicas | |||
| self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||
| return self.num_samples | |||
| def __iter__(self): | |||
| r""" | |||
| 当前使用num_consumed_samples做法会在交替使用的时候遇到问题; | |||
| Example: | |||
| >>> sampler = RandomSampler() | |||
| >>> iter1 = iter(sampler) | |||
| >>> iter2 = iter(sampler) | |||
| >>> next(iter1) | |||
| >>> next(iter2) # 当前num_consumed_samples的数量会发生变化 | |||
| """ | |||
| indices = self.generate_indices() | |||
| # subsample | |||
| indices = indices[self.rank:len(indices):self.num_replicas] | |||
| assert len(indices) == len(self) | |||
| for index in indices: | |||
| yield index | |||
| def generate_indices(self) -> List[int]: | |||
| """ | |||
| 生成随机序列 | |||
| :return: | |||
| """ | |||
| if self.shuffle: | |||
| indices = list(range(len(self.dataset))) | |||
| seed = self.seed + self.epoch | |||
| rng = np.random.default_rng(abs(seed)) | |||
| rng.shuffle(indices) | |||
| if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||
| self.epoch -= 1 | |||
| else: | |||
| indices = list(range(len(self.dataset))) | |||
| return indices | |||
| def set_epoch(self, epoch: int) -> None: | |||
| self.epoch = epoch | |||
| def set_distributed(self, num_replicas, rank): | |||
| """ | |||
| 该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用; | |||
| :param num_replicas: | |||
| :param rank: | |||
| :return: | |||
| """ | |||
| assert num_replicas>0 and isinstance(num_replicas, int) | |||
| assert isinstance(rank, int) and 0<=rank<num_replicas | |||
| # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | |||
| self.num_replicas = num_replicas | |||
| self.rank = rank | |||
| return self | |||
| @@ -0,0 +1,114 @@ | |||
| __all__ = [ | |||
| 'UnrepeatedSortedSampler', | |||
| 'UnrepeatedSampler' | |||
| ] | |||
| from typing import List, Union | |||
| from fastNLP.core.dataset import DataSet | |||
| import numpy as np | |||
| class UnrepeatedSampler: | |||
| def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | |||
| """ | |||
| 考虑在多卡evaluate的场景下,不能重复sample。 | |||
| :param dataset: | |||
| :param shuffle: | |||
| :param seed: | |||
| """ | |||
| self.dataset = dataset | |||
| self.shuffle = shuffle | |||
| self.seed = seed | |||
| # 多卡的相关的参数 | |||
| self.num_replicas = kwargs.get('num_replicas', 1) | |||
| self.rank = kwargs.get('rank', 0) | |||
| self.epoch = kwargs.get('epoch', -1) | |||
| def __len__(self): | |||
| """ | |||
| 返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; | |||
| :return: | |||
| """ | |||
| num_common = len(self.dataset)//self.num_replicas | |||
| self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||
| return self.num_samples | |||
| def __iter__(self): | |||
| indices = self.generate_indices() | |||
| # subsample | |||
| indices = indices[self.rank:len(indices):self.num_replicas] | |||
| assert len(indices) == len(self) | |||
| for index in indices: | |||
| yield index | |||
| def generate_indices(self) -> List[int]: | |||
| """ | |||
| 生成随机序列 | |||
| :return: | |||
| """ | |||
| if self.shuffle: | |||
| indices = list(range(len(self.dataset))) | |||
| seed = self.seed + self.epoch | |||
| rng = np.random.default_rng(abs(seed)) | |||
| rng.shuffle(indices) | |||
| if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||
| self.epoch -= 1 | |||
| else: | |||
| indices = list(range(len(self.dataset))) | |||
| return indices | |||
| def set_epoch(self, epoch: int) -> None: | |||
| self.epoch = epoch | |||
| def set_distributed(self, num_replicas, rank): | |||
| """ | |||
| 该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用; | |||
| :param num_replicas: | |||
| :param rank: | |||
| :return: | |||
| """ | |||
| assert num_replicas>0 and isinstance(num_replicas, int) | |||
| assert isinstance(rank, int) and 0<=rank<num_replicas | |||
| # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; | |||
| self.num_replicas = num_replicas | |||
| self.rank = rank | |||
| return self | |||
| class UnrepeatedSortedSampler(UnrepeatedSampler): | |||
| def __init__(self, dataset, length:Union[str, List], seed: int = 0): | |||
| """ | |||
| 将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 | |||
| batch 数量不完全一致。 | |||
| :param dataset: 实现了 __len__ 方法的数据容器。 | |||
| :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||
| DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | |||
| :param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||
| :param seed: 设置的随机数种子 | |||
| :param kwargs: fastNLP 保留使用 | |||
| """ | |||
| super().__init__(dataset=dataset, shuffle=False, seed=seed) | |||
| if isinstance(dataset, DataSet): | |||
| length = dataset.get_field(length) | |||
| if not isinstance(length[0], int): | |||
| length = list(map(len, length)) | |||
| else: | |||
| assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ | |||
| "the length parameter can only be List[int]" | |||
| assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | |||
| self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||
| self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | |||
| def generate_indices(self) -> List[int]: | |||
| return self.sorted_indices | |||
| @@ -0,0 +1,64 @@ | |||
| from itertools import chain | |||
| import pytest | |||
| from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler | |||
| class DatasetWithVaryLength: | |||
| def __init__(self, num_of_data=100): | |||
| self.data = list(range(num_of_data)) | |||
| def __getitem__(self, item): | |||
| return self.data[item] | |||
| def __len__(self): | |||
| return len(self.data) | |||
| class TestUnrepeatedSampler: | |||
| @pytest.mark.parametrize('shuffle', [True, False]) | |||
| def test_single(self, shuffle): | |||
| num_of_data = 100 | |||
| data = DatasetWithVaryLength(num_of_data) | |||
| sampler = UnrepeatedSampler(data, shuffle) | |||
| indexes = set(sampler) | |||
| assert indexes==set(range(num_of_data)) | |||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
| @pytest.mark.parametrize('shuffle', [False, True]) | |||
| def test_multi(self, num_replica, num_of_data, shuffle): | |||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
| samplers = [] | |||
| for i in range(num_replica): | |||
| sampler = UnrepeatedSampler(dataset=data, shuffle=shuffle) | |||
| sampler.set_distributed(num_replica, rank=i) | |||
| samplers.append(sampler) | |||
| indexes = set(chain(*samplers)) | |||
| assert indexes==set(range(num_of_data)) | |||
| class TestUnrepeatedSortedSampler: | |||
| @pytest.mark.parametrize('shuffle', [True, False]) | |||
| def test_single(self, shuffle): | |||
| num_of_data = 100 | |||
| data = DatasetWithVaryLength(num_of_data) | |||
| sampler = UnrepeatedSortedSampler(data, length=data.data) | |||
| indexes = list(sampler) | |||
| assert indexes==list(range(num_of_data-1, -1, -1)) | |||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
| @pytest.mark.parametrize('shuffle', [False, True]) | |||
| def test_multi(self, num_replica, num_of_data, shuffle): | |||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
| samplers = [] | |||
| for i in range(num_replica): | |||
| sampler = UnrepeatedSortedSampler(dataset=data, length=data.data) | |||
| sampler.set_distributed(num_replica, rank=i) | |||
| samplers.append(sampler) | |||
| indexes = set(chain(*samplers)) | |||
| assert indexes==set(range(num_of_data)) | |||