| @@ -219,6 +219,7 @@ class Evaluator: | |||
| def remove_progress_bar(self, dataloader_name): | |||
| if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||
| f_rich_progress.destroy_task(self._rich_task_id) | |||
| f_rich_progress.refresh() # 使得最终的bar可以消失 | |||
| delattr(self, '_rich_task_id') | |||
| elif self.progress_bar == 'raw': | |||
| desc = 'Evaluation ends' | |||
| @@ -229,6 +230,7 @@ class Evaluator: | |||
| def finally_progress_bar(self): | |||
| if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||
| f_rich_progress.destroy_task(self._rich_task_id) | |||
| f_rich_progress.refresh() | |||
| delattr(self, '_rich_task_id') | |||
| @property | |||
| @@ -23,7 +23,6 @@ from fastNLP.core.drivers import Driver | |||
| from fastNLP.core.drivers.utils import choose_driver | |||
| from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | |||
| from fastNLP.envs import rank_zero_call | |||
| from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||
| @@ -610,7 +609,7 @@ class Trainer(TrainerEventTrigger): | |||
| r""" | |||
| 用于断点重训的加载函数; | |||
| 注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 | |||
| 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator; | |||
| 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; | |||
| 注意我们目前不支持单卡到多卡的断点重训; | |||
| @@ -24,6 +24,7 @@ class _FDataSet: | |||
| 对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset | |||
| 中调用dataset的方法 | |||
| """ | |||
| def __init__(self, dataset) -> None: | |||
| self.dataset = dataset | |||
| @@ -45,6 +46,7 @@ class TorchDataLoader(DataLoader): | |||
| 提供给使用pytorch框架的DataLoader函数,若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作,用户也可以通过 | |||
| 提供的方法调节设置collate_fn的若干参数。 | |||
| """ | |||
| def __init__(self, dataset, batch_size: int = 1, | |||
| shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||
| batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||
| @@ -175,17 +177,17 @@ class TorchDataLoader(DataLoader): | |||
| def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | |||
| batch_size: int = 1, | |||
| shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||
| batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||
| num_workers: int = 0, collate_fn: Optional[Callable] = None, | |||
| pin_memory: bool = False, drop_last: bool = False, | |||
| timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
| multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||
| persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, | |||
| non_train_batch_size: int = 16, as_numpy: bool = False, | |||
| input_fields: Union[List, str] = None)\ | |||
| -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | |||
| batch_size: int = 1, | |||
| shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||
| batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||
| num_workers: int = 0, collate_fn: Optional[Callable] = None, | |||
| pin_memory: bool = False, drop_last: bool = False, | |||
| timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
| multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||
| persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, | |||
| non_train_batch_size: int = 16, as_numpy: bool = False, | |||
| input_fields: Union[List, str, None] = None) \ | |||
| -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | |||
| """ | |||
| 传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象 | |||
| @@ -221,7 +223,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
| multiprocessing_context=multiprocessing_context, generator=generator, | |||
| prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
| as_numpy=as_numpy) | |||
| dl.set_input(*input_fields) | |||
| if input_fields: | |||
| dl.set_input(*input_fields) | |||
| return dl | |||
| elif isinstance(ds_or_db, DataBundle): | |||
| @@ -233,17 +236,21 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
| num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
| drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
| multiprocessing_context=multiprocessing_context, generator=generator, | |||
| prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
| prefetch_factor=prefetch_factor, | |||
| persistent_workers=persistent_workers, | |||
| as_numpy=as_numpy) | |||
| else: | |||
| dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||
| shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, | |||
| shuffle=shuffle, sampler=non_train_sampler, | |||
| batch_sampler=batch_sampler, | |||
| num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
| drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
| multiprocessing_context=multiprocessing_context, generator=generator, | |||
| prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
| prefetch_factor=prefetch_factor, | |||
| persistent_workers=persistent_workers, | |||
| as_numpy=as_numpy) | |||
| dl_bundle[name].set_input(*input_fields) | |||
| if input_fields: | |||
| dl_bundle[name].set_input(*input_fields) | |||
| return dl_bundle | |||
| elif isinstance(ds_or_db, Sequence): | |||
| @@ -269,8 +276,9 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
| prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
| as_numpy=as_numpy) | |||
| ) | |||
| for dl in dl_bundle: | |||
| dl.set_input(*input_fields) | |||
| if input_fields: | |||
| for dl in dl_bundle: | |||
| dl.set_input(*input_fields) | |||
| return dl_bundle | |||
| elif isinstance(ds_or_db, Mapping): | |||
| @@ -282,18 +290,22 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||
| num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
| drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
| multiprocessing_context=multiprocessing_context, generator=generator, | |||
| prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
| prefetch_factor=prefetch_factor, | |||
| persistent_workers=persistent_workers, | |||
| as_numpy=as_numpy) | |||
| else: | |||
| dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||
| shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, | |||
| shuffle=shuffle, sampler=non_train_sampler, | |||
| batch_sampler=batch_sampler, | |||
| num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
| drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
| multiprocessing_context=multiprocessing_context, generator=generator, | |||
| prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
| prefetch_factor=prefetch_factor, | |||
| persistent_workers=persistent_workers, | |||
| as_numpy=as_numpy) | |||
| dl_bundle[name].set_input(*input_fields) | |||
| if input_fields: | |||
| dl_bundle[name].set_input(*input_fields) | |||
| return dl_bundle | |||
| else: | |||
| @@ -49,13 +49,13 @@ class Driver(ABC): | |||
| 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||
| 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||
| 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||
| 注意当 dist 为 ReproducibleIterator, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
| 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
| 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||
| :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||
| 可以加载。 | |||
| :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | |||
| 如果传入的 dataloader 中是 ReproducibleIterator 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||
| 如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||
| dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | |||
| """ | |||
| if dist is None and reproducible is False: | |||
| @@ -3,7 +3,7 @@ from typing import Optional, Union | |||
| from .jittor_driver import JittorDriver | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| from fastNLP.core.samplers import ReproducibleIterator | |||
| from fastNLP.core.samplers import ReproducibleSampler | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor | |||
| @@ -70,7 +70,7 @@ class JittorMPIDriver(JittorDriver): | |||
| def test_step(self, batch): | |||
| return self._test_step(batch) | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||
| reproducible: bool = False, sampler_or_batch_sampler=None): | |||
| pass | |||
| @@ -3,7 +3,7 @@ from typing import Dict, Union | |||
| from .jittor_driver import JittorDriver | |||
| from fastNLP.core.utils import auto_param_call | |||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
| if _NEED_IMPORT_JITTOR: | |||
| import jittor | |||
| @@ -99,25 +99,25 @@ class JittorSingleDriver(JittorDriver): | |||
| def is_distributed(self): | |||
| return False | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||
| reproducible: bool = False, sampler_or_batch_sampler=None): | |||
| # reproducible 的相关功能暂时没有实现 | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| raise NotImplementedError | |||
| dataloader.batch_sampler = dist_sample | |||
| if isinstance(dist, ReproducibleIterator): | |||
| if isinstance(dist, ReproducibleSampler): | |||
| raise NotImplementedError | |||
| dataloader.batch_sampler.sampler = dist | |||
| if reproducible: | |||
| raise NotImplementedError | |||
| if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||
| if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
| return dataloader | |||
| elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||
| elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||
| return dataloader | |||
| else: | |||
| # TODO | |||
| batch_sampler = ReproducibleBatchSampler( | |||
| batch_sampler = RandomBatchSampler( | |||
| batch_sampler=dataloader.batch_sampler, | |||
| batch_size=dataloader.batch_sampler.batch_size, | |||
| drop_last=dataloader.drop_last | |||
| @@ -19,7 +19,7 @@ from fastNLP.core.utils import ( | |||
| paddle_move_data_to_device, | |||
| is_in_paddle_dist, | |||
| ) | |||
| from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler | |||
| from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedRandomSampler | |||
| from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES | |||
| from fastNLP.core.log import logger | |||
| @@ -312,13 +312,13 @@ class PaddleFleetDriver(PaddleDriver): | |||
| def test_step(self, batch): | |||
| return self._test_step(batch) | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||
| reproducible: bool = False, sampler_or_batch_sampler=None): | |||
| # 暂时不支持iterableDataset | |||
| assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
| "FastNLP does not support `IteratorDataset` now." | |||
| if isinstance(dist, ReproducibleIterator): | |||
| if isinstance(dist, ReproducibleSampler): | |||
| dataloader.batch_sampler.sampler = dist | |||
| return dataloader | |||
| @@ -340,7 +340,7 @@ class PaddleFleetDriver(PaddleDriver): | |||
| # trainer | |||
| elif dist == "dist": | |||
| # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | |||
| if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||
| if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
| dataloader.batch_sampler.sampler.set_distributed( | |||
| num_replicas=self.world_size, | |||
| rank=self.global_rank, | |||
| @@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver): | |||
| return dataloader | |||
| # evaluator | |||
| elif dist == "unrepeatdist": | |||
| sampler = UnrepeatedSampler( | |||
| sampler = UnrepeatedRandomSampler( | |||
| dataset=dataloader.dataset, | |||
| shuffle=shuffle, | |||
| seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||
| @@ -10,7 +10,7 @@ from fastNLP.core.utils import ( | |||
| get_paddle_device_id, | |||
| paddle_move_data_to_device, | |||
| ) | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
| from fastNLP.core.log import logger | |||
| if _NEED_IMPORT_PADDLE: | |||
| @@ -139,7 +139,7 @@ class PaddleSingleDriver(PaddleDriver): | |||
| """ | |||
| return paddle_move_data_to_device(batch, "gpu:0") | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||
| reproducible: bool = False, sampler_or_batch_sampler=None): | |||
| # 暂时不支持IteratorDataset | |||
| assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
| @@ -147,12 +147,12 @@ class PaddleSingleDriver(PaddleDriver): | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| dataloader.batch_sampler = dist | |||
| return dataloader | |||
| if isinstance(dist, ReproducibleIterator): | |||
| if isinstance(dist, ReproducibleSampler): | |||
| dataloader.batch_sampler.sampler = dist | |||
| return dataloader | |||
| if reproducible: | |||
| if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||
| if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
| return dataloader | |||
| elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||
| return dataloader | |||
| @@ -28,11 +28,11 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||
| ) | |||
| from fastNLP.core.drivers.utils import distributed_open_proc | |||
| from fastNLP.core.utils import auto_param_call, check_user_specific_params | |||
| from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler | |||
| from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ | |||
| re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | |||
| from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | |||
| from fastNLP.core.samplers import re_instantiate_sampler | |||
| class TorchDDPDriver(TorchDriver): | |||
| @@ -446,13 +446,23 @@ class TorchDDPDriver(TorchDriver): | |||
| # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | |||
| return self._test_step(batch) | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None, | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, | |||
| reproducible: bool = False): | |||
| # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
| # 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||
| # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| dist.set_distributed( | |||
| num_replicas=self.world_size, | |||
| rank=self.global_rank, | |||
| pad=True | |||
| ) | |||
| return replace_batch_sampler(dataloader, dist) | |||
| if isinstance(dist, ReproducibleIterator): | |||
| if isinstance(dist, ReproducibleSampler): | |||
| dist.set_distributed( | |||
| num_replicas=self.world_size, | |||
| rank=self.global_rank, | |||
| pad=True | |||
| ) | |||
| return replace_sampler(dataloader, dist) | |||
| # 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用; | |||
| @@ -465,7 +475,7 @@ class TorchDDPDriver(TorchDriver): | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| dist = re_instantiate_sampler(dist) | |||
| return replace_batch_sampler(dataloader, dist) | |||
| if isinstance(dist, ReproducibleIterator): | |||
| if isinstance(dist, ReproducibleSampler): | |||
| dist = re_instantiate_sampler(dist) | |||
| return replace_sampler(dataloader, dist) | |||
| return dataloader | |||
| @@ -481,7 +491,7 @@ class TorchDDPDriver(TorchDriver): | |||
| pad=True | |||
| ) | |||
| return replace_batch_sampler(dataloader, batch_sampler) | |||
| elif isinstance(args.sampler, ReproducibleIterator): | |||
| elif isinstance(args.sampler, ReproducibleSampler): | |||
| sampler = re_instantiate_sampler(args.sampler) | |||
| sampler.set_distributed( | |||
| num_replicas=self.world_size, | |||
| @@ -503,14 +513,15 @@ class TorchDDPDriver(TorchDriver): | |||
| return replace_sampler(dataloader, sampler) | |||
| # evaluator | |||
| elif dist == "unrepeatdist": | |||
| # todo @yh,补充 unrepeatdist 相关内容; | |||
| args = self.get_dataloader_args(dataloader) | |||
| # todo 判断 batch_sampler; | |||
| sampler = UnrepeatedSampler( | |||
| dataset=args.dataset, | |||
| shuffle=args.shuffle, | |||
| ) | |||
| if isinstance(args.sampler, ReproducibleSampler): | |||
| sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||
| elif not isinstance(args.sampler, UnrepeatedSampler): | |||
| sampler = UnrepeatedSequentialSampler( | |||
| dataset=args.dataset | |||
| ) | |||
| else: | |||
| sampler = re_instantiate_sampler(args.sampler) | |||
| sampler.set_distributed( | |||
| num_replicas=self.world_size, | |||
| rank=self.global_rank | |||
| @@ -588,7 +599,7 @@ class TorchDDPDriver(TorchDriver): | |||
| :param group: | |||
| :return: | |||
| """ | |||
| return fastnlp_torch_all_gather(obj, device=self.data_device, group=group) | |||
| return fastnlp_torch_all_gather(obj, group=group) | |||
| def find_free_network_port() -> str: | |||
| @@ -1,11 +1,8 @@ | |||
| import io | |||
| import pickle | |||
| from typing import Mapping | |||
| _pickler = pickle.Pickler | |||
| _unpickler = pickle.Unpickler | |||
| from abc import ABC | |||
| from typing import Any, Union, List | |||
| import numpy as np | |||
| from typing import Any, List | |||
| from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||
| @@ -13,103 +10,25 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| from torch import distributed as dist | |||
| try: | |||
| from torch._C._distributed_c10d import ProcessGroupMPI | |||
| except ImportError: | |||
| _MPI_AVAILABLE = False | |||
| try: | |||
| from torch._C._distributed_c10d import ProcessGroupNCCL | |||
| except ImportError: | |||
| _NCCL_AVAILABLE = False | |||
| try: | |||
| from torch._C._distributed_c10d import ProcessGroupGloo | |||
| from torch._C._distributed_c10d import _ProcessGroupWrapper | |||
| except ImportError: | |||
| _GLOO_AVAILABLE = False | |||
| from fastNLP.core.utils import apply_to_collection | |||
| def all_gather_object(object_list, obj, group=None): | |||
| """ | |||
| Gathers picklable objects from the whole group into a list. Similar to | |||
| :func:`all_gather`, but Python objects can be passed in. Note that the object | |||
| must be picklable in order to be gathered. | |||
| Args: | |||
| object_list (list[Any]): Output list. It should be correctly sized as the | |||
| size of the group for this collective and will contain the output. | |||
| obj (Any): Picklable Python object to be broadcast from current process. | |||
| group (ProcessGroup, optional): The process group to work on. If None, | |||
| the default process group will be used. Default is ``None``. | |||
| Returns: | |||
| None. If the calling rank is part of this group, the output of the | |||
| collective will be populated into the input ``object_list``. If the | |||
| calling rank is not part of the group, the passed in ``object_list`` will | |||
| be unmodified. | |||
| .. note:: Note that this API differs slightly from the :func:`all_gather` | |||
| collective since it does not provide an ``async_op`` handle and thus | |||
| will be a blocking call. | |||
| .. note:: For NCCL-based processed groups, internal tensor representations | |||
| of objects must be moved to the GPU device before communication takes | |||
| place. In this case, the device used is given by | |||
| ``torch.cuda.current_device()`` and it is the user's responsibility to | |||
| ensure that this is set so that each rank has an individual GPU, via | |||
| ``torch.cuda.set_device()``. | |||
| .. warning:: | |||
| :func:`all_gather_object` uses ``pickle`` module implicitly, which is | |||
| known to be insecure. It is possible to construct malicious pickle data | |||
| which will execute arbitrary code during unpickling. Only call this | |||
| function with data you trust. | |||
| Example:: | |||
| >>> # Note: Process group initialization omitted on each rank. | |||
| >>> import torch.distributed as dist | |||
| >>> # Assumes world_size of 3. | |||
| >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object | |||
| >>> output = [None for _ in gather_objects] | |||
| >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) | |||
| >>> output | |||
| ['foo', 12, {1: 2}] | |||
| """ | |||
| if dist.distributed_c10d._rank_not_in_group(group): | |||
| return | |||
| input_tensor, local_size = _object_to_tensor(obj) | |||
| current_device = torch.device("cpu") | |||
| if dist.is_nccl_available() and isinstance( | |||
| group or dist.distributed_c10d._get_default_group(), dist.ProcessGroupNCCL | |||
| ): | |||
| # See note about using torch.cuda.current_device() here in docstring. | |||
| # We cannot simply use my_rank since rank == device is not necessarily | |||
| # true. | |||
| current_device = torch.device("cuda", torch.cuda.current_device()) | |||
| input_tensor = input_tensor.to(current_device) | |||
| local_size = local_size.to(current_device) | |||
| # Gather all local sizes. This is so that we can find the max size, and index | |||
| # until the correct size when deserializing the tensors. | |||
| group_size = dist.get_world_size(group=group) | |||
| object_sizes_tensor = torch.zeros( | |||
| group_size, dtype=torch.long, device=current_device | |||
| ) | |||
| object_size_list = [ | |||
| object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||
| ] | |||
| # Allgather tensor sizes | |||
| dist.all_gather(object_size_list, local_size, group=group) | |||
| max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||
| # Resize tensor to max size across all ranks. | |||
| input_tensor.resize_(max_object_size) | |||
| coalesced_output_tensor = torch.empty( | |||
| max_object_size * group_size, dtype=torch.uint8, device=current_device | |||
| ) | |||
| # Output tensors are nonoverlapping views of coalesced_output_tensor | |||
| output_tensors = [ | |||
| coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] | |||
| for i in range(group_size) | |||
| ] | |||
| dist.all_gather(output_tensors, input_tensor, group=group) | |||
| # Deserialize outputs back to object. | |||
| for i, tensor in enumerate(output_tensors): | |||
| tensor = tensor.type(torch.uint8) | |||
| if tensor.device != torch.device("cpu"): | |||
| tensor = tensor.cpu() | |||
| tensor_size = object_size_list[i] | |||
| object_list[i] = _tensor_to_object(tensor, tensor_size) | |||
| def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||
| if dst == my_rank: | |||
| if not gather_list: | |||
| @@ -123,8 +42,10 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||
| ) | |||
| def gather_object(obj, object_gather_list=None, dst=0, group=None): | |||
| def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): | |||
| """ | |||
| 从其它 rank gather 东西到 dst rank 。 | |||
| Gathers picklable objects from the whole group in a single process. | |||
| Similar to :func:`gather`, but Python objects can be passed in. Note that the | |||
| object must be picklable in order to be gathered. | |||
| @@ -176,6 +97,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None): | |||
| # Ensure object_gather_list is specified appropriately. | |||
| my_rank = dist.get_rank() | |||
| _validate_output_list_for_rank(my_rank, dst, object_gather_list) | |||
| # 防止 unpickle 的时候出现在了发送的 gpu 上。 | |||
| obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
| input_tensor, local_size = _object_to_tensor(obj) | |||
| group_backend = dist.get_backend(group) | |||
| current_device = torch.device("cpu") | |||
| @@ -266,113 +189,11 @@ def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): | |||
| return _tensor_to_object(tensor.cpu(), size) | |||
| def _all_gather(obj, **kwargs): | |||
| group = kwargs.get('group', None) | |||
| if isinstance(obj, torch.Tensor): | |||
| gathered_tensor = [torch.zeros_like(obj) for _ in | |||
| range(torch.distributed.get_world_size(group=group))] | |||
| torch.distributed.all_gather(gathered_tensor, obj, group=group) | |||
| return gathered_tensor | |||
| elif isinstance(obj, tuple) and isinstance(obj[1], torch.Tensor): | |||
| tensor, size = obj | |||
| # 首先需要同步 size 吧? | |||
| group_size = dist.get_world_size(group=group) | |||
| object_sizes_tensor = torch.zeros( | |||
| group_size, dtype=torch.long, device=tensor.device | |||
| ) | |||
| object_size_list = [ | |||
| object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) | |||
| ] | |||
| dist.all_gather(object_size_list, size, group=group) | |||
| max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] | |||
| # Resize tensor to max size across all ranks. | |||
| tensor.resize_(max_object_size) | |||
| coalesced_output_tensor = torch.empty( | |||
| max_object_size * group_size, dtype=torch.uint8, device=tensor.device | |||
| ) | |||
| # Output tensors are nonoverlapping views of coalesced_output_tensor | |||
| output_tensors = [ | |||
| coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)] | |||
| for i in range(group_size) | |||
| ] | |||
| dist.all_gather(output_tensors, tensor, group=group) | |||
| object_list = [] | |||
| for i, tensor in enumerate(output_tensors): | |||
| tensor = tensor.type(torch.uint8) | |||
| tensor_size = object_size_list[i] | |||
| object_list.append(_tensor_to_object(tensor, tensor_size)) | |||
| return object_list | |||
| elif isinstance(obj, tuple) and len(obj) == 2: | |||
| obj, _type = obj | |||
| gathered_tensor = [torch.zeros_like(obj) for _ in | |||
| range(torch.distributed.get_world_size(group=group))] | |||
| torch.distributed.all_gather(gathered_tensor, obj, group=group) | |||
| if _type == np.ndarray: | |||
| gathered_tensor = [t.detach().cpu().numpy() for t in gathered_tensor] | |||
| else: | |||
| gathered_tensor = [_type(t.item()) for t in gathered_tensor] | |||
| return gathered_tensor | |||
| else: | |||
| raise RuntimeError("Unsupported types to implement all_gather.") | |||
| class CanTransferDataType(ABC): | |||
| """ | |||
| 检测可以进行传输的对象。 | |||
| """ | |||
| @classmethod | |||
| def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||
| if cls is CanTransferDataType: | |||
| if issubclass(subclass, Mapping): | |||
| return False | |||
| if subclass in (torch.Tensor, tuple, list, str, int, float, bool, np.ndarray): | |||
| return True | |||
| return False | |||
| return NotImplemented | |||
| def _tensorize(obj, device=None): | |||
| if isinstance(obj, torch.Tensor): | |||
| return obj | |||
| if isinstance(obj, bool): | |||
| return torch.tensor(obj, dtype=torch.uint8, device=device), bool | |||
| if isinstance(obj, float): | |||
| return torch.tensor(obj, dtype=torch.float, device=device), float | |||
| if isinstance(obj, int): | |||
| return torch.tensor(obj, dtype=torch.int, device=device), int | |||
| if isinstance(obj, np.ndarray): | |||
| return torch.from_numpy(obj), np.ndarray | |||
| return _object_to_tensor(obj, device) | |||
| def _to_device(tensor, device): | |||
| return tensor.contiguous().to(device) | |||
| def convert_to_tensors(data: Any, device=None) -> Any: | |||
| data = apply_to_collection(data, CanTransferDataType, _tensorize) | |||
| def _move_to_device_and_make_contiguous(t: Union[torch.Tensor, tuple], device: Union[str, torch.device]): | |||
| if isinstance(t, tuple): | |||
| if isinstance(t[1], torch.Tensor): # 说明是 object 转的 | |||
| return t[0].to(device).contiguous(), t[1].to(device) | |||
| else: # 说明第二个元素是type,见 to_dtype_tensor 函数 | |||
| return t[0].to(device).contiguous(), t[1] | |||
| return t.to(device).contiguous() | |||
| data = apply_to_collection(data, (torch.Tensor, tuple), _move_to_device_and_make_contiguous, device=device) | |||
| return data | |||
| def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||
| def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
| """ | |||
| 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | |||
| @@ -390,36 +211,28 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||
| {'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||
| ] | |||
| :param obj: 任意结构的数据,所有的 value 都会变成 list ,其长度为 world_size ,依次为每个 rank 上的对象值 | |||
| :param device: 当前 rank 使用的 device 是哪个。为 None 的话默认使用 torch.cuda.current_device() 获取。 | |||
| :param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行 | |||
| 序列化之后进行传输。 | |||
| :param device: 当前该参数无意义。 | |||
| :param group: | |||
| :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | |||
| """ | |||
| # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 | |||
| # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
| if device is None: | |||
| device = torch.cuda.current_device() | |||
| if _TORCH_GREATER_EQUAL_1_8: | |||
| if isinstance(obj, torch.Tensor): | |||
| objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | |||
| dist.all_gather(objs, obj, group=group) | |||
| else: | |||
| objs = [None for _ in range(dist.get_world_size(group))] | |||
| dist.all_gather_object(objs, obj) | |||
| objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 | |||
| return objs | |||
| group = group if group is not None else torch.distributed.group.WORLD | |||
| data = convert_to_tensors(obj, device=device) | |||
| data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) | |||
| objs = [] | |||
| def _get_obj_on_idx(obj, idx): | |||
| return obj[idx] | |||
| for i in range(dist.get_world_size(group)): | |||
| objs.append(apply_to_collection(data, dtype=list, function=_get_obj_on_idx, idx=i)) | |||
| # 防止 unpickle 的时候弄到发送的 gpu 上了 | |||
| obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
| if _TORCH_GREATER_EQUAL_1_8: | |||
| dist.all_gather_object(objs, obj, group=group) | |||
| else: | |||
| objs = all_gather_object(objs, obj, group=group) | |||
| return objs | |||
| def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||
| def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||
| """ | |||
| 将 src 上的 obj 对象广播到其它 rank 上。 | |||
| @@ -430,10 +243,9 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||
| :return: | |||
| """ | |||
| cur_rank = dist.get_rank(group) | |||
| # if cur_rank == src: | |||
| # # 如果有 tensor 全部移动到 cpu 上,方便 pickle | |||
| # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
| if cur_rank == src: | |||
| # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | |||
| obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
| if _TORCH_GREATER_EQUAL_1_8: | |||
| if cur_rank!=src: | |||
| get_obj = [None] | |||
| @@ -442,6 +254,8 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||
| else: | |||
| dist.broadcast_object_list([obj], src=src, group=group) | |||
| return obj | |||
| if device is None: | |||
| device = torch.cuda.current_device() | |||
| if cur_rank == src: | |||
| tensor, size = _object_to_tensor(obj, device=device) | |||
| @@ -460,3 +274,107 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): | |||
| return _tensor_to_object(tensor, tensor_size=size.item()) | |||
def _check_for_nccl_backend(group):
    """
    Tell whether ``group`` is backed by NCCL.

    :param group: the process group to inspect; a falsy value selects the default process group.
    :return: ``True`` only when NCCL is available and the (unwrapped) group is a ``ProcessGroupNCCL``.
    """
    process_group = group if group else dist.distributed_c10d._get_default_group()
    # A process group is not expected to be wrapped many times, but peel every
    # wrapper layer just in case.
    while isinstance(process_group, _ProcessGroupWrapper):
        process_group = process_group.wrapped_pg
    if not dist.is_nccl_available():
        return False
    return isinstance(process_group, dist.ProcessGroupNCCL)
def all_gather_object(object_list, obj, group=None):
    """
    Gathers picklable objects from the whole group into a list.

    This is a copy of the PyTorch implementation, kept here so that older
    PyTorch versions can be supported as well. Similar to :func:`all_gather`,
    but Python objects can be passed in. Note that the object must be picklable
    in order to be gathered.

    Args:
        object_list (list[Any]): Output list. It should be correctly sized as the
            size of the group for this collective and will contain the output.
        obj (Any): Pickable Python object to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.

    Returns:
        list[Any]: ``object_list`` itself, so that callers may either rely on
        the in-place update or use the return value. If the calling rank is
        part of this group, the output of the collective has been populated
        into ``object_list``. If the calling rank is not part of the group,
        ``object_list`` is returned unmodified.

    .. note:: Note that this API differs slightly from the :func:`all_gather`
        collective since it does not provide an ``async_op`` handle and thus
        will be a blocking call.

    .. note:: For NCCL-based processed groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsiblity to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. warning::
        :func:`all_gather_object` uses ``pickle`` module implicitly, which is
        known to be insecure. It is possible to construct malicious pickle data
        which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
        >>> output
        ['foo', 12, {1: 2}]
    """
    if dist._rank_not_in_group(group):
        # Nothing to gather on this rank; hand the list back untouched.
        return object_list

    input_tensor, local_size = _object_to_tensor(obj)
    current_device = torch.device("cpu")
    is_nccl_backend = _check_for_nccl_backend(group)
    if is_nccl_backend:
        # See note about using torch.cuda.current_device() here in docstring.
        # We cannot simply use my_rank since rank == device is not necessarily
        # true.
        current_device = torch.device("cuda", torch.cuda.current_device())
        input_tensor = input_tensor.to(current_device)
        local_size = local_size.to(current_device)
    # Gather all local sizes. This is so that we can find the max size, and index
    # until the correct size when deserializing the tensors.
    group_size = dist.get_world_size(group=group)
    object_sizes_tensor = torch.zeros(
        group_size, dtype=torch.long, device=current_device
    )
    object_size_list = [
        object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
    ]
    # Allgather tensor sizes
    dist.all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Resize tensor to max size across all ranks.
    input_tensor.resize_(max_object_size)
    coalesced_output_tensor = torch.empty(
        max_object_size * group_size, dtype=torch.uint8, device=current_device
    )
    # Output tensors are nonoverlapping views of coalesced_output_tensor
    output_tensors = [
        coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)]
        for i in range(group_size)
    ]
    dist.all_gather(output_tensors, input_tensor, group=group)
    # Deserialize outputs back to object.
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.type(torch.uint8)
        if tensor.device != torch.device("cpu"):
            tensor = tensor.cpu()
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size)
    # Callers in this module assign the result (``objs = all_gather_object(...)``),
    # so the populated list must be returned instead of an implicit ``None``.
    return object_list
| @@ -13,9 +13,8 @@ __all__ = [ | |||
| from .torch_driver import TorchDriver | |||
| from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | |||
| from fastNLP.core.utils import auto_param_call | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.samplers import re_instantiate_sampler | |||
| class TorchSingleDriver(TorchDriver): | |||
| @@ -130,13 +129,13 @@ class TorchSingleDriver(TorchDriver): | |||
| else: | |||
| return self._test_step(batch) | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||
| reproducible: bool = False): | |||
| # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| return replace_batch_sampler(dataloader, dist) | |||
| elif isinstance(dist, ReproducibleIterator): | |||
| elif isinstance(dist, ReproducibleSampler): | |||
| return replace_sampler(dataloader, dist) | |||
| # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||
| @@ -144,7 +143,7 @@ class TorchSingleDriver(TorchDriver): | |||
| if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
| batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
| return replace_batch_sampler(dataloader, batch_sampler) | |||
| elif isinstance(args.sampler, ReproducibleIterator): | |||
| elif isinstance(args.sampler, ReproducibleSampler): | |||
| sampler = re_instantiate_sampler(args.sampler) | |||
| return replace_sampler(dataloader, sampler) | |||
| @@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||
| from fastNLP.envs import rank_zero_call | |||
| from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
| class TorchDriver(Driver): | |||
| @@ -182,8 +182,8 @@ class TorchDriver(Driver): | |||
| # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | |||
| # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | |||
| # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的 | |||
| # sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | |||
| # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 | |||
| # sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | |||
| dataloader_args = self.get_dataloader_args(dataloader) | |||
| if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
| sampler = dataloader_args.batch_sampler | |||
| @@ -247,11 +247,10 @@ class TorchDriver(Driver): | |||
| dataloader_args = self.get_dataloader_args(dataloader) | |||
| if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
| sampler = dataloader_args.batch_sampler | |||
| elif isinstance(dataloader_args.sampler, ReproducibleIterator): | |||
| elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||
| sampler = dataloader_args.sampler | |||
| elif self.is_distributed(): | |||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " | |||
| "`ReproducibleBatchSampler` or `ReproducibleIterator`.") | |||
| raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||
| else: | |||
| sampler = ReproducibleBatchSampler( | |||
| batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||
| @@ -291,7 +290,7 @@ class TorchDriver(Driver): | |||
| @staticmethod | |||
| def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||
| """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed | |||
| """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed | |||
| with ``seed_everything(seed, workers=True)``. | |||
| See also the PyTorch documentation on | |||
| @@ -9,18 +9,28 @@ __all__ = [ | |||
| 'MixSequentialSampler', | |||
| 'PollingSampler', | |||
| 'ReproducibleIterator', | |||
| 'ReproducibleSampler', | |||
| 'RandomSampler', | |||
| 're_instantiate_sampler', | |||
| "SequentialSampler", | |||
| "SortedSampler", | |||
| 'UnrepeatedSampler', | |||
| "UnrepeatedSortedSampler" | |||
| 'UnrepeatedRandomSampler', | |||
| "UnrepeatedSortedSampler", | |||
| "UnrepeatedSequentialSampler", | |||
| "RandomBatchSampler", | |||
| "BucketedBatchSampler", | |||
| "ReproducibleBatchSampler", | |||
| "re_instantiate_sampler", | |||
| "conversion_between_reproducible_and_unrepeated_sampler" | |||
| ] | |||
| from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | |||
| from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler | |||
| from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | |||
| from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||
| from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler | |||
| from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler | |||
| from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | |||
| from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler | |||
| from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | |||
| @@ -1,6 +1,6 @@ | |||
| __all__ = [ | |||
| 'BucketedBatchSampler', | |||
| "ReproducibleBatchSampler" | |||
| "RandomBatchSampler" | |||
| ] | |||
| import math | |||
| @@ -16,7 +16,10 @@ from fastNLP.core.log import logger | |||
| from abc import abstractmethod | |||
| class ReproducibleBatchIterator: | |||
| class ReproducibleBatchSampler: | |||
| def __init__(self, **kwargs): | |||
| pass | |||
| @abstractmethod | |||
| def set_distributed(self, num_replicas, rank, pad=True): | |||
| raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") | |||
| @@ -41,19 +44,25 @@ class ReproducibleBatchIterator: | |||
| def set_epoch(self, epoch): | |||
| pass | |||
| @property | |||
| def batch_idx_in_epoch(self): | |||
| raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | |||
| class ReproducibleBatchSampler(ReproducibleBatchIterator): | |||
| class RandomBatchSampler(ReproducibleBatchSampler): | |||
| # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | |||
| def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | |||
| """ | |||
| 可以使得 batch_sampler 对象状态恢复的 wrapper 。 | |||
| :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproducibleBatchSampler 将首先遍历一边该对象,然后将迭代 | |||
| :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代 | |||
| 出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 | |||
| :param batch_size: 每个 batch 的大小是多少。 | |||
| :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | |||
| :param kwargs: fastNLP 内部使用。 | |||
| """ | |||
| super().__init__() | |||
| self.batch_sampler = batch_sampler | |||
| self.batch_size = batch_size | |||
| self.drop_last = drop_last | |||
| @@ -138,7 +147,7 @@ class ReproducibleBatchSampler(ReproducibleBatchIterator): | |||
| (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | |||
| class BucketedBatchSampler(ReproducibleBatchIterator): | |||
| class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | |||
| shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | |||
| """ | |||
| @@ -1,24 +1,21 @@ | |||
| from typing import Dict, List | |||
| from typing import Dict, List, Union | |||
| import math | |||
| import numpy as np | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.dataset import DataSet | |||
| __all__ = [ | |||
| 'ReproducibleIterator', | |||
| 'ReproducibleSampler', | |||
| 'RandomSampler', | |||
| 're_instantiate_sampler' | |||
| "SortedSampler", | |||
| "SequentialSampler" | |||
| ] | |||
| def re_instantiate_sampler(sampler): | |||
| all_attributes = vars(sampler) | |||
| return type(sampler)(**all_attributes) | |||
| class ReproducibleIterator: | |||
| class ReproducibleSampler: | |||
| """ | |||
| 注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||
| 注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||
| 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | |||
| """ | |||
| @@ -46,7 +43,7 @@ class ReproducibleIterator: | |||
| pass | |||
| class RandomSampler(ReproducibleIterator): | |||
| class RandomSampler(ReproducibleSampler): | |||
| def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | |||
| """ | |||
| @@ -156,8 +153,8 @@ class RandomSampler(ReproducibleIterator): | |||
| f"we cannot use {self.__class__.__name__} to load it." | |||
| length = states['length'] | |||
| assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||
| "and current dataset." | |||
| assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ | |||
| f"and current dataset({len(self.dataset)})." | |||
| self.seed = states['seed'] | |||
| self.epoch = states['epoch'] | |||
| self.num_consumed_samples = states['num_consumed_samples'] | |||
| @@ -214,9 +211,132 @@ class RandomSampler(ReproducibleIterator): | |||
| self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||
class SequentialSampler(RandomSampler):
    def __init__(self, dataset, dist_mode:str='interval', **kwargs):
        """
        Read ``dataset`` in order. With several cards the indices are interleaved; for example with two
        cards, card 0 takes [0, 2, 4, ...] and card 1 takes [1, 3, 5, ...].

        :param dataset: a data container that implements ``__len__``.
        :param dist_mode: accepted for interface compatibility; it is not used by this class.
        :param kwargs: reserved for fastNLP internal use.
        """
        # The sampler may be rebuilt from ``vars(sampler)`` (see `re_instantiate_sampler`); those attributes
        # presumably contain `shuffle` / `seed` (TODO confirm against RandomSampler.__init__), which would
        # collide with the fixed values passed below and raise a duplicate-keyword TypeError.
        kwargs.pop('shuffle', None)
        kwargs.pop('seed', None)
        super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)

    def __iter__(self):
        if self.during_iter:  # a previous iteration did not finish, so it has to be restarted from scratch
            self.num_consumed_samples = 0
        self.during_iter = True
        # Work on a private copy: padding below extends the list in place and must never
        # leak into a list that a subclass keeps as state.
        indices = list(self.generate_indices())

        if self.pad:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]

        assert len(indices) == self.total_size

        # subsample: drop what was already consumed, then take this rank's interleaved share
        indices = indices[self.num_consumed_samples:]
        indices = indices[self.rank:len(indices):self.num_replicas]
        assert len(indices) == self.num_left_samples

        for index in indices:
            # one step of this rank corresponds to `num_replicas` samples over all ranks
            self.num_consumed_samples += self.num_replicas
            yield index
        self.during_iter = False
        self.num_consumed_samples = 0

    def generate_indices(self) -> List[int]:
        """
        Generate the sequential index list ``[0, 1, ..., len(dataset) - 1]``.

        :return: a new list of indices in dataset order.
        """
        return list(range(len(self.dataset)))

    def state_dict(self) -> Dict:
        """
        Collect the state needed to resume this sampler.

        :return: a dict with the keys ``num_consumed_samples``, ``sampler_type`` and ``length``.
        """
        states = {
            'num_consumed_samples': self.num_consumed_samples,  # counts the samples consumed over all ranks
            'sampler_type': self.__class__.__name__,
            'length': len(self.dataset),
        }
        return states

    def load_state_dict(self, states: Dict):
        """
        Restore the state produced by :meth:`state_dict`.

        :param states: a dict containing ``sampler_type``, ``length`` and ``num_consumed_samples``.
        """
        # Loading is only allowed while no iteration is in progress.
        assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                          "during an unfinished iteration."

        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."

        length = states['length']
        assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
                                            f"and current dataset({len(self.dataset)})."
        self.num_consumed_samples = states['num_consumed_samples']
        if self.num_consumed_samples >= length:  # the checkpoint was taken after the last sample, so start over
            self.num_consumed_samples = 0
class SortedSampler(SequentialSampler):
    def __init__(self, dataset, length:Union[str, List], **kwargs):
        """
        Iterate over ``dataset`` from the longest sample to the shortest according to ``length``. With several
        cards, padding may make the last sample the longest one.

        :param dataset: a data container that implements ``__len__``.
        :param length: if it is a List it must have the same length as ``dataset`` and gives the length of every
            sample. A str is only supported when ``dataset`` is a fastNLP DataSet; it is then taken as a field
            name, and if the field elements are int they are used as the sample lengths (otherwise ``len`` of
            each element is used).
        :param kwargs: reserved for fastNLP internal use.
        """
        # When the sampler is rebuilt from ``vars(sampler)`` the attributes presumably carry `shuffle` / `seed`
        # (TODO confirm against RandomSampler.__init__); the parent passes fixed values for both, so forwarding
        # them again would raise a duplicate-keyword TypeError.
        kwargs.pop('shuffle', None)
        kwargs.pop('seed', None)
        super().__init__(dataset=dataset, **kwargs)
        # Only a field *name* can be looked up in the DataSet; an explicit list/array of lengths is used
        # as is, whatever the dataset type.
        if isinstance(dataset, DataSet) and isinstance(length, str):
            length = dataset.get_field(length)
            if not isinstance(length[0], int):
                length = list(map(len, length))
        else:
            assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
                                                "the length parameter can only be List[int]"

        assert len(length) == len(dataset), "The length of `data` and `length` should be equal."

        self.length = np.array(length, dtype=int)  # length of every sample
        self.sorted_indices = np.argsort(self.length)[::-1].tolist()  # indices ordered from longest to shortest

    def generate_indices(self) -> List[int]:
        """
        Return the indices ordered from the longest sample to the shortest.

        :return: a copy of the stored order. ``__iter__`` (inherited from ``SequentialSampler``) pads the list
            it receives in place, so handing out the stored list itself would grow it after a padded epoch.
        """
        return list(self.sorted_indices)
| @@ -1,6 +1,8 @@ | |||
| __all__ = [ | |||
| 'UnrepeatedSampler', | |||
| 'UnrepeatedSortedSampler', | |||
| 'UnrepeatedSampler' | |||
| 'UnrepeatedRandomSampler', | |||
| "UnrepeatedSequentialSampler" | |||
| ] | |||
| from typing import List, Union | |||
| @@ -10,13 +12,21 @@ import numpy as np | |||
class UnrepeatedSampler:
    """
    Base class of the samplers that guarantee indices are not repeated in multi-card scenarios.
    """
| class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
| def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | |||
| """ | |||
| 考虑在多卡evaluate的场景下,不能重复sample。 | |||
| :param dataset: | |||
| :param shuffle: | |||
| :param seed: | |||
| :param dataset: 实现了 __len__ 方法的数据容器。 | |||
| :param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||
| :param seed: 设置的随机数种子 | |||
| :param kwargs: fastNLP 保留使用 | |||
| """ | |||
| self.dataset = dataset | |||
| self.shuffle = shuffle | |||
| @@ -33,8 +43,8 @@ class UnrepeatedSampler: | |||
| :return: | |||
| """ | |||
| num_common = len(self.dataset)//self.num_replicas | |||
| self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||
| return self.num_samples | |||
| num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||
| return num_samples | |||
| def __iter__(self): | |||
| indices = self.generate_indices() | |||
| @@ -83,8 +93,8 @@ class UnrepeatedSampler: | |||
| return self | |||
| class UnrepeatedSortedSampler(UnrepeatedSampler): | |||
| def __init__(self, dataset, length:Union[str, List], seed: int = 0): | |||
| class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||
| def __init__(self, dataset, length:Union[str, List], **kwargs): | |||
| """ | |||
| 将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 | |||
| batch 数量不完全一致。 | |||
| @@ -92,11 +102,9 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): | |||
| :param dataset: 实现了 __len__ 方法的数据容器。 | |||
| :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||
| DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | |||
| :param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||
| :param seed: 设置的随机数种子 | |||
| :param kwargs: fastNLP 保留使用 | |||
| """ | |||
| super().__init__(dataset=dataset, shuffle=False, seed=seed) | |||
| super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||
| if isinstance(dataset, DataSet): | |||
| length = dataset.get_field(length) | |||
| if not isinstance(length[0], int): | |||
| @@ -107,8 +115,29 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): | |||
| assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | |||
| self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||
| self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | |||
| length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||
| self.sorted_indices = np.argsort(length)[::-1].tolist() # 按长度从高到低排序的 | |||
| def generate_indices(self) -> List[int]: | |||
| return self.sorted_indices | |||
class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
    def __init__(self, dataset, **kwargs):
        """
        Read ``dataset`` in order. With several cards the indices are interleaved; for example with two
        cards, card 0 takes [0, 2, 4, ...] and card 1 takes [1, 3, 5, ...].

        :param dataset: a data container that implements ``__len__``.
        :param kwargs: reserved for fastNLP internal use.
        """
        # `re_instantiate_sampler` rebuilds a sampler from ``vars(sampler)``, which contains the `shuffle`
        # attribute stored by the parent class (and presumably `seed`). Forwarding them next to the fixed
        # values below would raise "got multiple values for keyword argument", so drop them first.
        kwargs.pop('shuffle', None)
        kwargs.pop('seed', None)
        super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs)

    def __iter__(self):
        indices = self.generate_indices()
        # every rank takes each `num_replicas`-th index, starting at its own rank
        yield from indices[self.rank:len(indices):self.num_replicas]

    def generate_indices(self) -> List[int]:
        """
        Generate the sequential index list ``[0, 1, ..., len(dataset) - 1]``.

        :return: a new list of indices in dataset order.
        """
        return list(range(len(self.dataset)))
| @@ -0,0 +1,42 @@ | |||
| __all__ = [ | |||
| 're_instantiate_sampler', | |||
| 'conversion_between_reproducible_and_unrepeated_sampler' | |||
| ] | |||
| from fastNLP.core.samplers.unrepeated_sampler import * | |||
| from fastNLP.core.samplers.reproducible_sampler import * | |||
def conversion_between_reproducible_and_unrepeated_sampler(sampler):
    """
    Replace ``sampler`` with its corresponding reproducible or unrepeated counterpart.

    An ``UnrepeatedSampler`` is converted to the matching ``ReproducibleSampler`` and vice versa; the new
    sampler is built by `re_instantiate_sampler` from the attributes of ``sampler``.

    :param sampler: an ``UnrepeatedSampler`` or ``ReproducibleSampler`` instance.
    :return: the re-instantiated sampler of the counterpart class.
    :raises TypeError: if no counterpart class is known for ``sampler``.
    """
    assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
        "The sampler must be UnrepeatedSampler or ReproducibleSampler"
    if isinstance(sampler, UnrepeatedSampler):
        # The sorted and sequential variants subclass UnrepeatedRandomSampler, so they have to be tested
        # before it; otherwise every sampler would match the random branch.
        if isinstance(sampler, UnrepeatedSortedSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
        elif isinstance(sampler, UnrepeatedSequentialSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
        elif isinstance(sampler, UnrepeatedRandomSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
        raise TypeError(f"{sampler.__class__} has no reproducible version.")
    else:
        # Same ordering rule: SortedSampler -> SequentialSampler -> RandomSampler is an inheritance chain,
        # so the most derived class is checked first.
        if isinstance(sampler, SortedSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
        elif isinstance(sampler, SequentialSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
        elif isinstance(sampler, RandomSampler):
            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
        raise TypeError(f"{sampler.__class__} has no unrepeated version.")
def re_instantiate_sampler(sampler, new_sampler_class=None):
    """
    Build a fresh sampler from the instance attributes of ``sampler``.

    :param sampler: the object whose ``vars()`` are passed as keyword arguments to the constructor.
    :param new_sampler_class: the class to instantiate; when ``None`` the class of ``sampler`` is used.
    :return: the newly created instance.
    """
    target_class = type(sampler) if new_sampler_class is None else new_sampler_class
    return target_class(**vars(sampler))
| @@ -94,9 +94,6 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
| self.print = self.console.print | |||
| self.log = self.console.log | |||
| # start new | |||
| self.start() | |||
| self.console.show_cursor(show=True) | |||
| return self | |||
| def set_transient(self, transient: bool = True): | |||
| @@ -154,6 +151,7 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
| super().start() | |||
| self.console.show_cursor(show=True) | |||
| if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
| f_rich_progress = FRichProgress().new_progess( | |||
| "[progress.description]{task.description}", | |||
| @@ -1,4 +1,4 @@ | |||
| import unittest | |||
| import pytest | |||
| from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader | |||
| from fastNLP.core.dataset import DataSet | |||
| @@ -17,7 +17,7 @@ class RandomDataset(Dataset): | |||
| return 10 | |||
| class TestPaddle(unittest.TestCase): | |||
| class TestPaddle: | |||
| def test_init(self): | |||
| # ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) | |||
| @@ -1,11 +1,11 @@ | |||
| import unittest | |||
| import pytest | |||
| from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.io.data_bundle import DataBundle | |||
| class TestFdl(unittest.TestCase): | |||
| class TestFdl: | |||
| def test_init_v1(self): | |||
| ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
| @@ -1,12 +1,12 @@ | |||
| import os | |||
| import unittest | |||
| import pytest | |||
| import numpy as np | |||
| from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException | |||
| class TestDataSetInit(unittest.TestCase): | |||
| class TestDataSetInit: | |||
| """初始化DataSet的办法有以下几种: | |||
| 1) 用dict: | |||
| 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||
| @@ -24,46 +24,46 @@ class TestDataSetInit(unittest.TestCase): | |||
| def test_init_v1(self): | |||
| # 一维list | |||
| ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) | |||
| self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||
| self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||
| self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||
| assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True | |||
| assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40 | |||
| assert ds.field_arrays["y"].content == [[5, 6], ] * 40 | |||
| def test_init_v2(self): | |||
| # 用dict | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||
| self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||
| self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||
| assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True | |||
| assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40 | |||
| assert ds.field_arrays["y"].content == [[5, 6], ] * 40 | |||
| def test_init_assert(self): | |||
| with self.assertRaises(AssertionError): | |||
| with pytest.raises(AssertionError): | |||
| _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}) | |||
| with self.assertRaises(AssertionError): | |||
| with pytest.raises(AssertionError): | |||
| _ = DataSet([[1, 2, 3, 4]] * 10) | |||
| with self.assertRaises(ValueError): | |||
| with pytest.raises(ValueError): | |||
| _ = DataSet(0.00001) | |||
| class TestDataSetMethods(unittest.TestCase): | |||
| class TestDataSetMethods: | |||
| def test_append(self): | |||
| dd = DataSet() | |||
| for _ in range(3): | |||
| dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6])) | |||
| self.assertEqual(len(dd), 3) | |||
| self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) | |||
| self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) | |||
| assert len(dd) == 3 | |||
| assert dd.field_arrays["x"].content == [[1, 2, 3, 4]] * 3 | |||
| assert dd.field_arrays["y"].content == [[5, 6]] * 3 | |||
| def test_add_field(self): | |||
| dd = DataSet() | |||
| dd.add_field("x", [[1, 2, 3]] * 10) | |||
| dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||
| dd.add_field("z", [[5, 6]] * 10) | |||
| self.assertEqual(len(dd), 10) | |||
| self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10) | |||
| self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | |||
| self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | |||
| assert len(dd) == 10 | |||
| assert dd.field_arrays["x"].content == [[1, 2, 3]] * 10 | |||
| assert dd.field_arrays["y"].content == [[1, 2, 3, 4]] * 10 | |||
| assert dd.field_arrays["z"].content == [[5, 6]] * 10 | |||
| with self.assertRaises(RuntimeError): | |||
| with pytest.raises(RuntimeError): | |||
| dd.add_field("??", [[1, 2]] * 40) | |||
| def test_delete_field(self): | |||
| @@ -71,8 +71,8 @@ class TestDataSetMethods(unittest.TestCase): | |||
| dd.add_field("x", [[1, 2, 3]] * 10) | |||
| dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||
| dd.delete_field("x") | |||
| self.assertFalse("x" in dd.field_arrays) | |||
| self.assertTrue("y" in dd.field_arrays) | |||
| assert ("x" in dd.field_arrays) == False | |||
| assert "y" in dd.field_arrays | |||
| def test_delete_instance(self): | |||
| dd = DataSet() | |||
| @@ -80,30 +80,30 @@ class TestDataSetMethods(unittest.TestCase): | |||
| dd.add_field("x", [[1, 2, 3]] * old_length) | |||
| dd.add_field("y", [[1, 2, 3, 4]] * old_length) | |||
| dd.delete_instance(0) | |||
| self.assertEqual(len(dd), old_length - 1) | |||
| assert len(dd) == old_length - 1 | |||
| dd.delete_instance(0) | |||
| self.assertEqual(len(dd), old_length - 2) | |||
| assert len(dd) == old_length - 2 | |||
| def test_getitem(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| ins_1, ins_0 = ds[0], ds[1] | |||
| self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance)) | |||
| self.assertEqual(ins_1["x"], [1, 2, 3, 4]) | |||
| self.assertEqual(ins_1["y"], [5, 6]) | |||
| self.assertEqual(ins_0["x"], [1, 2, 3, 4]) | |||
| self.assertEqual(ins_0["y"], [5, 6]) | |||
| assert isinstance(ins_1, Instance) and isinstance(ins_0, Instance) == True | |||
| assert ins_1["x"] == [1, 2, 3, 4] | |||
| assert ins_1["y"] == [5, 6] | |||
| assert ins_0["x"] == [1, 2, 3, 4] | |||
| assert ins_0["y"] == [5, 6] | |||
| sub_ds = ds[:10] | |||
| self.assertTrue(isinstance(sub_ds, DataSet)) | |||
| self.assertEqual(len(sub_ds), 10) | |||
| assert isinstance(sub_ds, DataSet) == True | |||
| assert len(sub_ds) == 10 | |||
| sub_ds_1 = ds[[10, 0, 2, 3]] | |||
| self.assertTrue(isinstance(sub_ds_1, DataSet)) | |||
| self.assertEqual(len(sub_ds_1), 4) | |||
| assert isinstance(sub_ds_1, DataSet) == True | |||
| assert len(sub_ds_1) == 4 | |||
| field_array = ds['x'] | |||
| self.assertTrue(isinstance(field_array, FieldArray)) | |||
| self.assertEqual(len(field_array), 40) | |||
| assert isinstance(field_array, FieldArray) == True | |||
| assert len(field_array) == 40 | |||
| def test_setitem(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| @@ -120,73 +120,73 @@ class TestDataSetMethods(unittest.TestCase): | |||
| assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y'] | |||
| def test_get_item_error(self): | |||
| with self.assertRaises(RuntimeError): | |||
| with pytest.raises(RuntimeError): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
| _ = ds[40:] | |||
| with self.assertRaises(KeyError): | |||
| with pytest.raises(KeyError): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
| _ = ds["kom"] | |||
| def test_len_(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| self.assertEqual(len(ds), 40) | |||
| assert len(ds) == 40 | |||
| ds = DataSet() | |||
| self.assertEqual(len(ds), 0) | |||
| assert len(ds) == 0 | |||
| def test_add_fieldarray(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*40)) | |||
| self.assertEqual(ds['z'].content, [[7, 8]]*40) | |||
| ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 40)) | |||
| assert ds['z'].content == [[7, 8]] * 40 | |||
| with self.assertRaises(RuntimeError): | |||
| ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*10)) | |||
| with pytest.raises(RuntimeError): | |||
| ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 10)) | |||
| with self.assertRaises(TypeError): | |||
| with pytest.raises(TypeError): | |||
| ds.add_fieldarray('z', [1, 2, 4]) | |||
| def test_copy_field(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| ds.copy_field('x', 'z') | |||
| self.assertEqual(ds['x'].content, ds['z'].content) | |||
| assert ds['x'].content == ds['z'].content | |||
| def test_has_field(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| self.assertTrue(ds.has_field('x')) | |||
| self.assertFalse(ds.has_field('z')) | |||
| assert ds.has_field('x') == True | |||
| assert ds.has_field('z') == False | |||
| def test_get_field(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| with self.assertRaises(KeyError): | |||
| with pytest.raises(KeyError): | |||
| ds.get_field('z') | |||
| x_array = ds.get_field('x') | |||
| self.assertEqual(x_array.content, [[1, 2, 3, 4]] * 40) | |||
| assert x_array.content == [[1, 2, 3, 4]] * 40 | |||
| def test_get_all_fields(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| field_arrays = ds.get_all_fields() | |||
| self.assertEqual(field_arrays["x"], [[1, 2, 3, 4]] * 40) | |||
| self.assertEqual(field_arrays['y'], [[5, 6]] * 40) | |||
| assert field_arrays["x"].content == [[1, 2, 3, 4]] * 40 | |||
| assert field_arrays['y'].content == [[5, 6]] * 40 | |||
| def test_get_field_names(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| field_names = ds.get_field_names() | |||
| self.assertTrue('x' in field_names) | |||
| self.assertTrue('y' in field_names) | |||
| assert 'x' in field_names | |||
| assert 'y' in field_names | |||
| def test_apply(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000}) | |||
| ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx') | |||
| self.assertTrue("rx" in ds.field_arrays) | |||
| self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) | |||
| assert ("rx" in ds.field_arrays) == True | |||
| assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1] | |||
| ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False) | |||
| self.assertEqual(ds.field_arrays["y"].content[0], 2) | |||
| assert ds.field_arrays["y"].content[0] == 2 | |||
| res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len") | |||
| self.assertTrue(isinstance(res, list) and len(res) > 0) | |||
| self.assertTrue(res[0], 4) | |||
| assert (isinstance(res, list) and len(res) > 0) == True | |||
| assert res[0] == 4 | |||
| ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k") | |||
| # expect no exception raised | |||
| @@ -206,6 +206,7 @@ class TestDataSetMethods(unittest.TestCase): | |||
| def modify_inplace(instance): | |||
| instance['words'] = 1 | |||
| ds.apply(modify_inplace) | |||
| # with self.assertRaises(TypeError): | |||
| # ds.apply(modify_inplace) | |||
| @@ -230,48 +231,48 @@ class TestDataSetMethods(unittest.TestCase): | |||
| T.apply_more(func_1) | |||
| # print(T['c'][0, 1, 2]) | |||
| self.assertEqual(list(T["c"].content), [2, 4, 6]) | |||
| self.assertEqual(list(T["d"].content), [1, 4, 9]) | |||
| assert list(T["c"].content) == [2, 4, 6] | |||
| assert list(T["d"].content) == [1, 4, 9] | |||
| res = T.apply_field_more(func_2, "a", modify_fields=False) | |||
| self.assertEqual(list(T["c"].content), [2, 4, 6]) | |||
| self.assertEqual(list(T["d"].content), [1, 4, 9]) | |||
| self.assertEqual(list(res["c"]), [3, 6, 9]) | |||
| self.assertEqual(list(res["d"]), [1, 8, 27]) | |||
| assert list(T["c"].content) == [2, 4, 6] | |||
| assert list(T["d"].content) == [1, 4, 9] | |||
| assert list(res["c"]) == [3, 6, 9] | |||
| assert list(res["d"]) == [1, 8, 27] | |||
| with self.assertRaises(ApplyResultException) as e: | |||
| with pytest.raises(ApplyResultException) as e: | |||
| T.apply_more(func_err_1) | |||
| print(e) | |||
| with self.assertRaises(ApplyResultException) as e: | |||
| with pytest.raises(ApplyResultException) as e: | |||
| T.apply_field_more(func_err_2, "a") | |||
| print(e) | |||
| def test_drop(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | |||
| ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) | |||
| self.assertEqual(len(ds), 20) | |||
| assert len(ds) == 20 | |||
| def test_contains(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
| self.assertTrue("x" in ds) | |||
| self.assertTrue("y" in ds) | |||
| self.assertFalse("z" in ds) | |||
| assert ("x" in ds) == True | |||
| assert ("y" in ds) == True | |||
| assert ("z" in ds) == False | |||
| def test_rename_field(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
| ds.rename_field("x", "xx") | |||
| self.assertTrue("xx" in ds) | |||
| self.assertFalse("x" in ds) | |||
| assert ("xx" in ds) == True | |||
| assert ("x" in ds) == False | |||
| with self.assertRaises(KeyError): | |||
| with pytest.raises(KeyError): | |||
| ds.rename_field("yyy", "oo") | |||
| def test_split(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
| d1, d2 = ds.split(0.1) | |||
| self.assertEqual(len(d1), len(ds)*0.9) | |||
| self.assertEqual(len(d2), len(ds)*0.1) | |||
| assert len(d2) == (len(ds) * 0.9) | |||
| assert len(d1) == (len(ds) * 0.1) | |||
| def test_add_field_v2(self): | |||
| ds = DataSet({"x": [3, 4]}) | |||
| @@ -282,14 +283,14 @@ class TestDataSetMethods(unittest.TestCase): | |||
| def test_save_load(self): | |||
| ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
| ds.save("./my_ds.pkl") | |||
| self.assertTrue(os.path.exists("./my_ds.pkl")) | |||
| assert os.path.exists("./my_ds.pkl") == True | |||
| ds_1 = DataSet.load("./my_ds.pkl") | |||
| os.remove("my_ds.pkl") | |||
| def test_add_null(self): | |||
| ds = DataSet() | |||
| with self.assertRaises(RuntimeError) as RE: | |||
| with pytest.raises(RuntimeError) as RE: | |||
| ds.add_field('test', []) | |||
| def test_concat(self): | |||
| @@ -301,16 +302,16 @@ class TestDataSetMethods(unittest.TestCase): | |||
| ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]}) | |||
| ds3 = ds1.concat(ds2) | |||
| self.assertEqual(len(ds3), 20) | |||
| assert len(ds3) == 20 | |||
| self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4]) | |||
| self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1]) | |||
| assert ds1[9]['x'] == [1, 2, 3, 4] | |||
| assert ds1[10]['x'] == [4, 3, 2, 1] | |||
| ds2[0]['x'][0] = 100 | |||
| self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了 | |||
| assert ds3[10]['x'][0] == 4 # 不改变copy后的field了 | |||
| ds3[10]['x'][0] = -100 | |||
| self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了 | |||
| assert ds2[0]['x'][0] == 100 # 不改变copy前的field了 | |||
| # 测试inplace | |||
| ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
| @@ -318,19 +319,19 @@ class TestDataSetMethods(unittest.TestCase): | |||
| ds3 = ds1.concat(ds2, inplace=True) | |||
| ds2[0]['x'][0] = 100 | |||
| self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了 | |||
| assert ds3[10]['x'][0] == 4 # 不改变copy后的field了 | |||
| ds3[10]['x'][0] = -100 | |||
| self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了 | |||
| assert ds2[0]['x'][0] == 100 # 不改变copy前的field了 | |||
| ds3[0]['x'][0] = 100 | |||
| self.assertEqual(ds1[0]['x'][0], 100) # 改变copy前的field了 | |||
| assert ds1[0]['x'][0] == 100 # 改变copy前的field了 | |||
| # 测试mapping | |||
| ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
| ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]}) | |||
| ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'}) | |||
| self.assertEqual(len(ds3), 20) | |||
| assert len(ds3) == 20 | |||
| # 测试忽略掉多余的 | |||
| ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
| @@ -340,7 +341,7 @@ class TestDataSetMethods(unittest.TestCase): | |||
| # 测试报错 | |||
| ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
| ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]}) | |||
| with self.assertRaises(RuntimeError): | |||
| with pytest.raises(RuntimeError): | |||
| ds3 = ds1.concat(ds2, field_mapping={'X': 'x'}) | |||
| def test_instance_field_disappear_bug(self): | |||
| @@ -348,7 +349,7 @@ class TestDataSetMethods(unittest.TestCase): | |||
| data.copy_field(field_name='raw_chars', new_field_name='chars') | |||
| _data = data[:1] | |||
| for field_name in ['raw_chars', 'target', 'chars']: | |||
| self.assertTrue(_data.has_field(field_name)) | |||
| assert _data.has_field(field_name) == True | |||
| def test_from_pandas(self): | |||
| import pandas as pd | |||
| @@ -356,8 +357,8 @@ class TestDataSetMethods(unittest.TestCase): | |||
| df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]}) | |||
| ds = DataSet.from_pandas(df) | |||
| print(ds) | |||
| self.assertEqual(ds['x'].content, [1, 2, 3]) | |||
| self.assertEqual(ds['y'].content, [4, 5, 6]) | |||
| assert ds['x'].content == [1, 2, 3] | |||
| assert ds['y'].content == [4, 5, 6] | |||
| def test_to_pandas(self): | |||
| ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) | |||
| @@ -366,7 +367,7 @@ class TestDataSetMethods(unittest.TestCase): | |||
| def test_to_csv(self): | |||
| ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) | |||
| ds.to_csv("1.csv") | |||
| self.assertTrue(os.path.exists("1.csv")) | |||
| assert os.path.exists("1.csv") == True | |||
| os.remove("1.csv") | |||
| def test_add_collate_fn(self): | |||
| @@ -374,27 +375,26 @@ class TestDataSetMethods(unittest.TestCase): | |||
| def collate_fn(item): | |||
| return item | |||
| ds.add_collate_fn(collate_fn) | |||
| self.assertEqual(len(ds.collate_fns.collators), 2) | |||
| ds.add_collate_fn(collate_fn) | |||
| def test_get_collator(self): | |||
| from typing import Callable | |||
| ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) | |||
| collate_fn = ds.get_collator() | |||
| self.assertEqual(isinstance(collate_fn, Callable), True) | |||
| assert isinstance(collate_fn, Callable) == True | |||
| def test_add_seq_len(self): | |||
| ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]}) | |||
| ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]}) | |||
| ds.add_seq_len('x') | |||
| print(ds) | |||
| def test_set_target(self): | |||
| ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]}) | |||
| ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]}) | |||
| ds.set_target('x') | |||
| class TestFieldArrayInit(unittest.TestCase): | |||
| class TestFieldArrayInit: | |||
| """ | |||
| 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: | |||
| 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||
| @@ -442,7 +442,6 @@ class TestFieldArrayInit(unittest.TestCase): | |||
| # list of array | |||
| fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])]) | |||
| def test_init_v8(self): | |||
| # 二维list | |||
| val = np.array([[1, 2], [3, 4]]) | |||
| @@ -450,78 +449,78 @@ class TestFieldArrayInit(unittest.TestCase): | |||
| fa.append(val) | |||
| class TestFieldArray(unittest.TestCase): | |||
| class TestFieldArray: | |||
| def test_main(self): | |||
| fa = FieldArray("x", [1, 2, 3, 4, 5]) | |||
| self.assertEqual(len(fa), 5) | |||
| assert len(fa) == 5 | |||
| fa.append(6) | |||
| self.assertEqual(len(fa), 6) | |||
| assert len(fa) == 6 | |||
| self.assertEqual(fa[-1], 6) | |||
| self.assertEqual(fa[0], 1) | |||
| assert fa[-1] == 6 | |||
| assert fa[0] == 1 | |||
| fa[-1] = 60 | |||
| self.assertEqual(fa[-1], 60) | |||
| assert fa[-1] == 60 | |||
| self.assertEqual(fa.get(0), 1) | |||
| self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) | |||
| self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) | |||
| assert fa.get(0) == 1 | |||
| assert isinstance(fa.get([0, 1, 2]), np.ndarray) == True | |||
| assert list(fa.get([0, 1, 2])) == [1, 2, 3] | |||
| def test_getitem_v1(self): | |||
| fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) | |||
| self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | |||
| assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5] | |||
| ans = fa[[0, 1]] | |||
| self.assertTrue(isinstance(ans, np.ndarray)) | |||
| self.assertTrue(isinstance(ans[0], np.ndarray)) | |||
| self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5]) | |||
| self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5]) | |||
| self.assertEqual(ans.dtype, np.float64) | |||
| assert isinstance(ans, np.ndarray) == True | |||
| assert isinstance(ans[0], np.ndarray) == True | |||
| assert ans[0].tolist() == [1.1, 2.2, 3.3, 4.4, 5.5] | |||
| assert ans[1].tolist() == [1, 2, 3, 4, 5] | |||
| assert ans.dtype == np.float64 | |||
| def test_getitem_v2(self): | |||
| x = np.random.rand(10, 5) | |||
| fa = FieldArray("my_field", x) | |||
| indices = [0, 1, 3, 4, 6] | |||
| for a, b in zip(fa[indices], x[indices]): | |||
| self.assertListEqual(a.tolist(), b.tolist()) | |||
| assert a.tolist() == b.tolist() | |||
| def test_append(self): | |||
| fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) | |||
| fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) | |||
| self.assertEqual(len(fa), 3) | |||
| self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) | |||
| assert len(fa) == 3 | |||
| assert fa[2] == [1.2, 2.3, 3.4, 4.5, 5.6] | |||
| def test_pop(self): | |||
| fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) | |||
| fa.pop(0) | |||
| self.assertEqual(len(fa), 1) | |||
| self.assertEqual(fa[0], [1.0, 2.0, 3.0, 4.0, 5.0]) | |||
| assert len(fa) == 1 | |||
| assert fa[0] == [1.0, 2.0, 3.0, 4.0, 5.0] | |||
| fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5] | |||
| self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | |||
| assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5] | |||
| class TestCase(unittest.TestCase): | |||
| class TestCase: | |||
| def test_init(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||
| ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) | |||
| self.assertTrue(isinstance(ins.fields, dict)) | |||
| self.assertEqual(ins.fields, fields) | |||
| assert isinstance(ins.fields, dict) == True | |||
| assert ins.fields == fields | |||
| ins = Instance(**fields) | |||
| self.assertEqual(ins.fields, fields) | |||
| assert ins.fields == fields | |||
| def test_add_field(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||
| ins = Instance(**fields) | |||
| ins.add_field("z", [1, 1, 1]) | |||
| fields.update({"z": [1, 1, 1]}) | |||
| self.assertEqual(ins.fields, fields) | |||
| assert ins.fields == fields | |||
| def test_get_item(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||
| ins = Instance(**fields) | |||
| self.assertEqual(ins["x"], [1, 2, 3]) | |||
| self.assertEqual(ins["y"], [4, 5, 6]) | |||
| self.assertEqual(ins["z"], [1, 1, 1]) | |||
| assert ins["x"] == [1, 2, 3] | |||
| assert ins["y"] == [4, 5, 6] | |||
| assert ins["z"] == [1, 1, 1] | |||
| def test_repr(self): | |||
| fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||
| @@ -10,7 +10,7 @@ from paddle.io import DataLoader, BatchSampler | |||
| from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||
| from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler | |||
| from fastNLP.core.samplers import RandomBatchSampler | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||
| from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | |||
| from fastNLP.core import synchronize_safe_rm | |||
| @@ -153,7 +153,7 @@ class TestSingleDeviceFunction: | |||
| @pytest.mark.parametrize( | |||
| "dist_sampler", | |||
| ["dist", ReproducibleBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||
| ["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||
| ) | |||
| @pytest.mark.parametrize( | |||
| "reproducible", | |||
| @@ -7,38 +7,10 @@ import numpy as np | |||
| # print(isinstance((1,), tuple)) | |||
| # exit() | |||
| from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, convert_to_tensors, fastnlp_torch_broadcast_object | |||
| from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | |||
| from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context | |||
| def test_convert_to_tensors(): | |||
| local_rank = 0 | |||
| obj = { | |||
| 'tensor': torch.full(size=(2,), fill_value=local_rank), | |||
| 'numpy': np.full(shape=(1,), fill_value=local_rank), | |||
| 'bool': local_rank % 2 == 0, | |||
| 'float': local_rank + 0.1, | |||
| 'int': local_rank, | |||
| 'dict': { | |||
| 'rank': local_rank | |||
| }, | |||
| 'list': [local_rank] * 2, | |||
| 'str': 'xxx' | |||
| } | |||
| data = convert_to_tensors(obj) | |||
| assert len(data) == len(obj) | |||
| assert (data['tensor'] == obj['tensor']).sum() == 2 | |||
| for name in ['list', 'str']: | |||
| assert len(data[name])==2 and isinstance(data[name][0], torch.Tensor) and \ | |||
| isinstance(data[name][1], torch.Tensor) and data[name][1].ndim==1 | |||
| for name in ['numpy', 'bool', 'float', 'int']: | |||
| assert isinstance(data[name][0], torch.Tensor) and data[name][0].numel()==1 | |||
| assert isinstance(data['dict']['rank'][0], torch.Tensor) and data[name][0].numel() == 1 | |||
| @magic_argv_env_context | |||
| def test_fastnlp_torch_all_gather(): | |||
| os.environ['MASTER_ADDR'] = '127.0.0.1' | |||
| @@ -66,7 +38,7 @@ def test_fastnlp_torch_all_gather(): | |||
| 'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(), | |||
| torch.full(size=(2,), fill_value=local_rank).cuda()] | |||
| } | |||
| data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) | |||
| data = fastnlp_torch_all_gather(obj) | |||
| world_size = int(os.environ['WORLD_SIZE']) | |||
| assert len(data) == world_size | |||
| for i in range(world_size): | |||
| @@ -81,10 +53,12 @@ def test_fastnlp_torch_all_gather(): | |||
| assert data[i]['tensors'][0][0] == i | |||
| for obj in [1, True, 'xxx']: | |||
| data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) | |||
| data = fastnlp_torch_all_gather(obj) | |||
| assert len(data)==world_size | |||
| assert data[0]==data[1] | |||
| dist.destroy_process_group() | |||
| @magic_argv_env_context | |||
| def test_fastnlp_torch_broadcast_object(): | |||
| os.environ['MASTER_ADDR'] = '127.0.0.1' | |||
| @@ -130,3 +104,4 @@ def test_fastnlp_torch_broadcast_object(): | |||
| for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]: | |||
| data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device()) | |||
| assert int(data)==0 | |||
| dist.destroy_process_group() | |||
| @@ -30,7 +30,7 @@ class SequenceDataSet: | |||
| def check_replace_sampler(driver): | |||
| # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler | |||
| # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler | |||
| # reproducible 是 True 和 False | |||
| # 需要 check 返回的 sampler 和 dataloader 都不同了 | |||
| @@ -4,7 +4,7 @@ import numpy as np | |||
| import pytest | |||
| from itertools import chain | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler | |||
| from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | |||
| from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
| @@ -18,7 +18,7 @@ class TestReproducibleBatchSampler: | |||
| before_batch_size = 7 | |||
| dataset = TorchNormalDataset(num_of_data=100) | |||
| dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||
| re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
| re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
| dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||
| forward_steps = 3 | |||
| @@ -28,15 +28,15 @@ class TestReproducibleBatchSampler: | |||
| # 1. 保存状态 | |||
| _get_re_batchsampler = dataloader.batch_sampler | |||
| assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||
| assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||
| state = _get_re_batchsampler.state_dict() | |||
| assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||
| "sampler_type": "ReproducibleBatchSampler"} | |||
| "sampler_type": "RandomBatchSampler"} | |||
| # 2. 断点重训,重新生成一个 dataloader; | |||
| # 不改变 batch_size; | |||
| dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||
| re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
| re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
| re_batchsampler.load_state_dict(state) | |||
| dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||
| @@ -53,7 +53,7 @@ class TestReproducibleBatchSampler: | |||
| # 改变 batch_size; | |||
| after_batch_size = 3 | |||
| dataloader = DataLoader(dataset, batch_size=after_batch_size) | |||
| re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
| re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
| re_batchsampler.load_state_dict(state) | |||
| dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||
| @@ -99,7 +99,7 @@ class TestReproducibleBatchSampler: | |||
| dataset = TorchNormalDataset(num_of_data=100) | |||
| # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | |||
| dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||
| re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
| re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
| dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||
| # 将一轮的所有数据保存下来,看是否恢复的是正确的; | |||
| @@ -111,13 +111,13 @@ class TestReproducibleBatchSampler: | |||
| # 1. 保存状态 | |||
| _get_re_batchsampler = dataloader.batch_sampler | |||
| assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||
| assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||
| state = _get_re_batchsampler.state_dict() | |||
| # 2. 断点重训,重新生成一个 dataloader; | |||
| # 不改变 batch_size; | |||
| dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||
| re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
| re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
| re_batchsampler.load_state_dict(state) | |||
| dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||
| @@ -1,18 +1,14 @@ | |||
| import unittest | |||
| from itertools import product | |||
| import numpy as np | |||
| import pytest | |||
| from functools import partial | |||
| from array import array | |||
| from itertools import chain | |||
| from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||
| from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||
| from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
| class TestRandomSamplerYh(unittest.TestCase): | |||
| class TestRandomSamplerYh: | |||
| def test_init(self): | |||
| # 测试能否正确初始化 | |||
| dataset = TorchNormalDataset(num_of_data=100) | |||
| @@ -24,7 +20,7 @@ class TestRandomSamplerYh(unittest.TestCase): | |||
| dataset = TorchNormalDataset(num_of_data=100) | |||
| sampler = RandomSampler(dataset) | |||
| for i in sampler: | |||
| with self.assertRaises(AssertionError): | |||
| with pytest.raises(AssertionError): | |||
| sampler.set_distributed(1, 0) | |||
| break | |||
| @@ -37,39 +33,39 @@ class TestRandomSamplerYh(unittest.TestCase): | |||
| dataset = TorchNormalDataset(num_of_data=100) | |||
| sampler = RandomSampler(dataset, shuffle=False) | |||
| sampler.set_distributed(num_replicas=2, rank=0, pad=False) | |||
| self.assertEqual(len(sampler), 50) | |||
| assert len(sampler)==50 | |||
| count = 0 | |||
| for i in sampler: | |||
| self.assertEqual(i%2, 0) | |||
| assert i%2==0 | |||
| count += 1 | |||
| self.assertEqual(count, 50) | |||
| assert count == 50 | |||
| sampler.set_distributed(num_replicas=2, rank=1, pad=False) | |||
| self.assertEqual(len(sampler), 50) | |||
| assert len(sampler)==50 | |||
| count = 0 | |||
| for i in sampler: | |||
| self.assertEqual(i%2, 1) | |||
| assert i%2==1 | |||
| count += 1 | |||
| self.assertEqual(count, 50) | |||
| assert count==50 | |||
| dataset = TorchNormalDataset(num_of_data=101) | |||
| sampler = RandomSampler(dataset, shuffle=False) | |||
| sampler.set_distributed(num_replicas=2, rank=0, pad=True) | |||
| self.assertEqual(len(sampler), 51) | |||
| assert len(sampler)==51 | |||
| count = 0 | |||
| for i in sampler: | |||
| self.assertEqual(i%2, 0) | |||
| assert i%2==0 | |||
| count += 1 | |||
| self.assertEqual(count, 51) | |||
| assert count == 51 | |||
| sampler.set_distributed(num_replicas=2, rank=1, pad=True) | |||
| self.assertEqual(len(sampler), 51) | |||
| assert len(sampler) == 51 | |||
| count = 0 | |||
| for i in sampler: | |||
| if i!=0: | |||
| self.assertEqual(i%2, 1) | |||
| assert i%2==1 | |||
| count += 1 | |||
| self.assertEqual(count, 51) | |||
| assert count == 51 | |||
| def test_state_dict_check_length(self): | |||
| dataset = TorchNormalDataset(num_of_data=100) | |||
| @@ -77,7 +73,7 @@ class TestRandomSamplerYh(unittest.TestCase): | |||
| states = sampler.state_dict() | |||
| new_ds = TorchNormalDataset(num_of_data=10) | |||
| with self.assertRaises(AssertionError): | |||
| with pytest.raises(AssertionError): | |||
| new_sampler = RandomSampler(new_ds) | |||
| new_sampler.load_state_dict(states) | |||
| @@ -85,99 +81,107 @@ class TestRandomSamplerYh(unittest.TestCase): | |||
| new_sampler = RandomSampler(new_ds) | |||
| new_sampler.load_state_dict(states) | |||
| def test_state_dict(self): | |||
| @pytest.mark.parametrize('pad', [True, False]) | |||
| @pytest.mark.parametrize('pre_shuffle', [True, False]) | |||
| @pytest.mark.parametrize('post_shuffle', [True, False]) | |||
| @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) | |||
| def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): | |||
| num_samples = 100 | |||
| dataset = TorchNormalDataset(num_of_data=num_samples) | |||
| # 测试使用 前后shuffle不一致的load操作 | |||
| lst = [0]+np.random.randint(1, num_samples, size=3).tolist() | |||
| for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], | |||
| lst): | |||
| with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): | |||
| sampler = RandomSampler(dataset, shuffle=pre_shuffle) | |||
| sampler.set_epoch(0) | |||
| already_numbers = set() | |||
| if num_consumed_samples>0: | |||
| for i, j in enumerate(sampler, start=1): | |||
| already_numbers.add(j) | |||
| if i == num_consumed_samples: | |||
| break | |||
| self.assertEqual(len(already_numbers), num_consumed_samples) | |||
| states = sampler.state_dict() | |||
| new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
| new_sampler.load_state_dict(states) | |||
| new_sampler.set_epoch(0) | |||
| for i in new_sampler: | |||
| self.assertNotIn(i, already_numbers) | |||
| # 测试切换成多卡也没有问题 | |||
| other_rank_number = set() | |||
| for rank in range(3): | |||
| new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
| new_sampler.load_state_dict(states) | |||
| new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) | |||
| new_sampler.set_epoch(0) | |||
| count = 0 | |||
| for i in new_sampler: | |||
| self.assertNotIn(i, other_rank_number) | |||
| other_rank_number.add(i) | |||
| self.assertNotIn(i, already_numbers) | |||
| count += 1 | |||
| def test_state_dict_2(self): | |||
| sampler = RandomSampler(dataset, shuffle=pre_shuffle) | |||
| sampler.set_epoch(0) | |||
| already_numbers = set() | |||
| if num_consumed_samples>0: | |||
| for i, j in enumerate(sampler, start=1): | |||
| already_numbers.add(j) | |||
| if i == num_consumed_samples: | |||
| break | |||
| assert len(already_numbers) == num_consumed_samples | |||
| states = sampler.state_dict() | |||
| new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
| new_sampler.load_state_dict(states) | |||
| new_sampler.set_epoch(0) | |||
| for i in new_sampler: | |||
| assert i not in already_numbers | |||
| # 测试切换成多卡也没有问题 | |||
| other_rank_number = set() | |||
| for rank in range(3): | |||
| new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
| new_sampler.load_state_dict(states) | |||
| new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||
| new_sampler.set_epoch(0) | |||
| count = 0 | |||
| seen = 0 | |||
| seen_in_other_rank = 0 | |||
| for i in new_sampler: | |||
| seen_in_other_rank += int(i in other_rank_number) | |||
| other_rank_number.add(i) | |||
| seen += int(i in already_numbers) | |||
| count += 1 | |||
| assert seen <= 1 if pad else seen == 0 | |||
| assert seen_in_other_rank<=1 # 因为pad可能重复 | |||
| @pytest.mark.parametrize('pad', [True, False]) | |||
| @pytest.mark.parametrize('pre_shuffle', [True, False]) | |||
| @pytest.mark.parametrize('post_shuffle', [True, False]) | |||
| @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) | |||
| def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): | |||
| # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 | |||
| num_samples = 100 | |||
| dataset = TorchNormalDataset(num_of_data=num_samples) | |||
| # 测试使用 前后shuffle不一致的load操作 | |||
| lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist() | |||
| # lst = [30] | |||
| for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], | |||
| lst): | |||
| with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): | |||
| already_numbers = set() | |||
| sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||
| sampler.set_distributed(num_replicas=2, rank=0) | |||
| sampler.set_epoch(0) | |||
| if num_consumed_samples>0: | |||
| for i, j in enumerate(sampler, start=1): | |||
| already_numbers.add(j) | |||
| if i == num_consumed_samples: | |||
| break | |||
| sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||
| sampler.set_epoch(0) | |||
| sampler.set_distributed(num_replicas=2, rank=1) | |||
| if num_consumed_samples>0: | |||
| for i, j in enumerate(sampler, start=1): | |||
| already_numbers.add(j) | |||
| if i == num_consumed_samples: | |||
| break | |||
| self.assertEqual(len(already_numbers), num_consumed_samples*2) | |||
| states = sampler.state_dict() | |||
| new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
| new_sampler.load_state_dict(states) | |||
| new_sampler.set_epoch(0) | |||
| for i in new_sampler: | |||
| self.assertNotIn(i, already_numbers) | |||
| # 测试切换成多卡也没有问题 | |||
| other_rank_number = set() | |||
| for rank in range(3): | |||
| new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
| new_sampler.load_state_dict(states) | |||
| new_sampler.set_epoch(0) | |||
| new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) | |||
| count = 0 | |||
| for i in new_sampler: | |||
| self.assertNotIn(i, other_rank_number) | |||
| other_rank_number.add(i) | |||
| self.assertNotIn(i, already_numbers) | |||
| count += 1 | |||
| class TestRandomSampler(unittest.TestCase): | |||
| already_numbers = set() | |||
| sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||
| sampler.set_distributed(num_replicas=2, rank=0) | |||
| sampler.set_epoch(0) | |||
| if num_consumed_samples>0: | |||
| for i, j in enumerate(sampler, start=1): | |||
| already_numbers.add(j) | |||
| if i == num_consumed_samples: | |||
| break | |||
| sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||
| sampler.set_epoch(0) | |||
| sampler.set_distributed(num_replicas=2, rank=1) | |||
| if num_consumed_samples>0: | |||
| for i, j in enumerate(sampler, start=1): | |||
| already_numbers.add(j) | |||
| if i == num_consumed_samples: | |||
| break | |||
| assert len(already_numbers) == num_consumed_samples*2 | |||
| states = sampler.state_dict() | |||
| new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
| new_sampler.load_state_dict(states) | |||
| new_sampler.set_epoch(0) | |||
| for i in new_sampler: | |||
| assert i not in already_numbers | |||
| # 测试切换成多卡也没有问题 | |||
| other_rank_number = set() | |||
| for rank in range(3): | |||
| new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
| new_sampler.load_state_dict(states) | |||
| new_sampler.set_epoch(0) | |||
| new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||
| count = 0 | |||
| seen = 0 | |||
| seen_in_other_rank = 0 | |||
| for i in new_sampler: | |||
| seen_in_other_rank += int(i in other_rank_number) | |||
| other_rank_number.add(i) | |||
| seen += int(i in already_numbers) | |||
| count += 1 | |||
| assert seen <= 1 if pad else seen == 0 | |||
| assert seen_in_other_rank<=1 # 因为pad可能重复 | |||
| class TestRandomSampler: | |||
| # 测试单卡; | |||
| def test_seed_work_when_shuffle_is_true(self): | |||
| data_length = 100 | |||
| @@ -360,4 +364,324 @@ class TestRandomSampler(unittest.TestCase): | |||
| ... | |||
class DatasetWithVaryLength:
    """Minimal in-memory dataset holding the integers ``0 .. num_of_data - 1``.

    The sampler tests in this file pass ``data`` as the samplers' ``length``
    argument, so each stored value also serves as that item's length.
    """

    def __init__(self, num_of_data=100, reverse=False):
        """Create the backing array.

        :param num_of_data: number of items to generate.
        :param reverse: when true, store the items in descending order.
        """
        self.data = np.arange(num_of_data)
        if reverse:
            self.data = self.data[::-1]

    def __getitem__(self, item):
        """Return the stored value(s) selected by ``item``."""
        return self.data[item]

    def __len__(self):
        """Return the number of stored items."""
        return len(self.data)
class TestSortedSampler:
    """Tests for ``SortedSampler``: ordering, distributed splitting and resuming."""

    def test_single(self):
        """On a single card the sampler yields indices from longest to shortest."""
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = SortedSampler(data, length=data.data)
        indexes = list(sampler)
        assert indexes == list(range(num_of_data - 1, -1, -1))

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, pad, num_replica, num_of_data):
        """Each rank keeps the descending order and ranks only overlap via padding."""
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replica):
            sampler = SortedSampler(dataset=data, length=data.data)
            sampler.set_distributed(num_replica, rank=i, pad=pad)
            samplers.append(sampler)

        # Make sure the order has not been disturbed.
        already_seen_index = set()
        for sampler in samplers:
            # Starting at 0 is enough: the first comparison against +inf always
            # counts, and an index appended by padding is always a larger number.
            ordered_count = 0
            prev_index = float('inf')
            cur_set = set()
            seen_in_other_rank = 0
            for index in sampler:
                # Different ranks must not share indices (padding aside).
                seen_in_other_rank += int(index in already_seen_index)
                cur_set.add(index)
                ordered_count += int(index <= prev_index)
                prev_index = index
            # Only the last element may be out of order; all others must keep it.
            assert ordered_count + 1 >= len(sampler)
            assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0
            already_seen_index.update(cur_set)

        indexes = set(chain(*samplers))
        if pad:
            assert indexes == set(range(num_of_data))
        else:
            assert len(indexes) <= num_of_data

    # The random sample sizes are drawn from a seeded generator so that the
    # parametrize ids are identical on every collection (every run, every worker).
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize(
        'num_consumed_samples',
        [0] + np.random.RandomState(0).randint(1, 100, size=3).tolist())
    def test_state_dict(self, pad, num_consumed_samples):
        """Resume from a single-card state on one card and on three cards."""
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)

        # Consume the first ``num_consumed_samples`` indices on a single card.
        sampler = SortedSampler(dataset, length=dataset.data)
        sampler.set_epoch(0)
        already_numbers = set()
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j < max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples

        states = sampler.state_dict()
        # ``already_numbers`` is not modified below, so its maximum is computed once.
        max_consumed = max(already_numbers) if already_numbers else None

        new_sampler = SortedSampler(dataset, length=dataset.data)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if max_consumed is not None:
                assert i < max_consumed
            assert i not in already_numbers

        # Switching to multiple cards after loading must work as well.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SortedSampler(dataset, length=dataset.data)
            new_sampler.load_state_dict(states)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            new_sampler.set_epoch(0)
            seen = 0
            seen_in_other_rank = 0
            out_of_order = 0  # indices not below the largest consumed one
            for i in new_sampler:
                if max_consumed is not None:
                    out_of_order += int(i >= max_consumed)
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank <= 1  # padding may repeat an index
            assert out_of_order <= 1 if pad else out_of_order == 0

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize(
        'num_consumed_samples',
        [0] + np.random.RandomState(0).randint(1, 100 // 2, size=3).tolist())
    def test_state_dict_2(self, pad, num_consumed_samples):
        """Resume from a two-card state on one card and on three cards."""
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)

        # Consume ``num_consumed_samples`` indices on each of the two ranks.
        already_numbers = set()
        sampler = SortedSampler(dataset, length=dataset.data)
        sampler.set_distributed(num_replicas=2, rank=0)
        sampler.set_epoch(0)
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j <= max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        sampler = SortedSampler(dataset, length=dataset.data)
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=2, rank=1)
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples * 2

        # The state is taken from the rank-1 sampler.
        states = sampler.state_dict()
        # ``already_numbers`` is not modified below, so its maximum is computed once.
        max_consumed = max(already_numbers) if already_numbers else None

        new_sampler = SortedSampler(dataset, length=dataset.data)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if max_consumed is not None:
                assert i < max_consumed
            assert i not in already_numbers

        # Switching to a different number of cards must work as well.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SortedSampler(dataset, length=dataset.data)
            new_sampler.load_state_dict(states)
            new_sampler.set_epoch(0)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            seen = 0
            seen_in_other_rank = 0
            out_of_order = 0  # indices not below the largest consumed one
            for i in new_sampler:
                if max_consumed is not None:
                    out_of_order += int(i >= max_consumed)
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank <= 1  # padding may repeat an index
            assert out_of_order <= 1 if pad else out_of_order == 0
class TestSequentialSampler:
    """Tests for ``SequentialSampler``: ordering, distributed splitting and resuming."""

    def test_single(self):
        """On a single card the sampler yields ``0 .. n - 1`` in order."""
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = SequentialSampler(data)
        indexes = list(sampler)
        assert indexes == list(range(num_of_data))

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, pad, num_replica, num_of_data):
        """Each rank keeps the ascending order and ranks only overlap via padding."""
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replica):
            sampler = SequentialSampler(dataset=data)
            sampler.set_distributed(num_replica, rank=i, pad=pad)
            samplers.append(sampler)

        # Make sure the order has not been disturbed.
        already_seen_index = set()
        for idx, sampler in enumerate(samplers):
            # Comparing the first index against -inf always counts, which is
            # equivalent to the previous "start at 1, compare against +inf" form.
            ordered_count = 0
            prev_index = float('-inf')
            cur_set = set()
            seen_in_other_rank = 0
            for index in sampler:
                # Different ranks must not share indices (padding aside).
                seen_in_other_rank += int(index in already_seen_index)
                cur_set.add(index)
                ordered_count += int(index >= prev_index)
                prev_index = index
            # Only the last element may be out of order; all others must keep it.
            assert ordered_count + 1 >= len(sampler)
            assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0
            already_seen_index.update(cur_set)

        indexes = set(chain(*samplers))
        if pad:
            assert indexes == set(range(num_of_data))
        else:
            assert len(indexes) <= num_of_data

    # The random sample sizes are drawn from a seeded generator so that the
    # parametrize ids are identical on every collection (every run, every worker).
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize(
        'num_consumed_samples',
        [0] + np.random.RandomState(0).randint(1, 100, size=3).tolist())
    def test_state_dict(self, pad, num_consumed_samples):
        """Resume from a single-card state on one card and on three cards."""
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)

        # Consume the first ``num_consumed_samples`` indices on a single card.
        sampler = SequentialSampler(dataset=dataset)
        sampler.set_epoch(0)
        already_numbers = set()
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j > max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples

        states = sampler.state_dict()
        # ``already_numbers`` is not modified below, so its maximum is computed once.
        max_consumed = max(already_numbers) if already_numbers else None

        new_sampler = SequentialSampler(dataset=dataset)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if max_consumed is not None:
                assert i > max_consumed
            assert i not in already_numbers

        # Switching to multiple cards after loading must work as well.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SequentialSampler(dataset=dataset)
            new_sampler.load_state_dict(states)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            new_sampler.set_epoch(0)
            seen = 0
            seen_in_other_rank = 0
            smaller = 0  # indices not above the largest consumed one
            for i in new_sampler:
                if max_consumed is not None:
                    smaller += int(i <= max_consumed)
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank <= rank  # padding may repeat an index
            assert smaller <= 1 if pad else smaller == 0

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize(
        'num_consumed_samples',
        [0] + np.random.RandomState(0).randint(1, 100 // 2, size=3).tolist())
    def test_state_dict_2(self, pad, num_consumed_samples):
        """Resume from a two-card state on one card and on three cards."""
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)

        # Consume ``num_consumed_samples`` indices on each of the two ranks.
        already_numbers = set()
        sampler = SequentialSampler(dataset=dataset)
        sampler.set_distributed(num_replicas=2, rank=0)
        sampler.set_epoch(0)
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j > max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        sampler = SequentialSampler(dataset=dataset)
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=2, rank=1)
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples * 2

        # The state is taken from the rank-1 sampler.
        states = sampler.state_dict()
        # ``already_numbers`` is not modified below, so its maximum is computed once.
        max_consumed = max(already_numbers) if already_numbers else None

        new_sampler = SequentialSampler(dataset=dataset)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if max_consumed is not None:
                assert i > max_consumed
            assert i not in already_numbers

        # Switching to a different number of cards must work as well.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SequentialSampler(dataset=dataset)
            new_sampler.load_state_dict(states)
            new_sampler.set_epoch(0)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            seen = 0
            seen_in_other_rank = 0
            smaller = 0  # indices below the largest consumed one
            for i in new_sampler:
                if max_consumed is not None:
                    smaller += int(i < max_consumed)
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank <= 1  # padding may repeat an index
            assert smaller <= rank if pad else smaller == 0
| @@ -2,7 +2,7 @@ from itertools import chain | |||
| import pytest | |||
| from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler | |||
| from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | |||
| class DatasetWithVaryLength: | |||
| @@ -21,7 +21,7 @@ class TestUnrepeatedSampler: | |||
| def test_single(self, shuffle): | |||
| num_of_data = 100 | |||
| data = DatasetWithVaryLength(num_of_data) | |||
| sampler = UnrepeatedSampler(data, shuffle) | |||
| sampler = UnrepeatedRandomSampler(data, shuffle) | |||
| indexes = set(sampler) | |||
| assert indexes==set(range(num_of_data)) | |||
| @@ -32,17 +32,18 @@ class TestUnrepeatedSampler: | |||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
| samplers = [] | |||
| for i in range(num_replica): | |||
| sampler = UnrepeatedSampler(dataset=data, shuffle=shuffle) | |||
| sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle) | |||
| sampler.set_distributed(num_replica, rank=i) | |||
| samplers.append(sampler) | |||
| indexes = set(chain(*samplers)) | |||
| indexes = list(chain(*samplers)) | |||
| assert len(indexes) == num_of_data | |||
| indexes = set(indexes) | |||
| assert indexes==set(range(num_of_data)) | |||
| class TestUnrepeatedSortedSampler: | |||
| @pytest.mark.parametrize('shuffle', [True, False]) | |||
| def test_single(self, shuffle): | |||
| def test_single(self): | |||
| num_of_data = 100 | |||
| data = DatasetWithVaryLength(num_of_data) | |||
| sampler = UnrepeatedSortedSampler(data, length=data.data) | |||
| @@ -51,8 +52,7 @@ class TestUnrepeatedSortedSampler: | |||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
| @pytest.mark.parametrize('shuffle', [False, True]) | |||
| def test_multi(self, num_replica, num_of_data, shuffle): | |||
| def test_multi(self, num_replica, num_of_data): | |||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
| samplers = [] | |||
| for i in range(num_replica): | |||
| @@ -60,5 +60,45 @@ class TestUnrepeatedSortedSampler: | |||
| sampler.set_distributed(num_replica, rank=i) | |||
| samplers.append(sampler) | |||
| indexes = set(chain(*samplers)) | |||
| # 保证顺序是没乱的 | |||
| for sampler in samplers: | |||
| prev_index = float('inf') | |||
| for index in sampler: | |||
| assert index <= prev_index | |||
| prev_index = index | |||
| indexes = list(chain(*samplers)) | |||
| assert len(indexes) == num_of_data # 不同卡之间没有交叉 | |||
| indexes = set(indexes) | |||
| assert indexes==set(range(num_of_data)) | |||
class TestUnrepeatedSequentialSampler:
    """Tests for ``UnrepeatedSequentialSampler`` on one card and on several."""

    def test_single(self):
        """On a single card the sampler yields ``0 .. n - 1`` in order."""
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        # NOTE(review): ``length`` is passed exactly as before; whether the
        # sequential sampler uses it is not visible here -- confirm.
        sampler = UnrepeatedSequentialSampler(data, length=data.data)
        indexes = list(sampler)
        assert indexes == list(range(num_of_data))

    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, num_replica, num_of_data):
        """Ranks yield ascending, disjoint indices that together cover the data."""
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replica):
            sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data)
            sampler.set_distributed(num_replica, rank=i)
            samplers.append(sampler)

        # Make sure the order has not been disturbed: every rank must yield a
        # non-decreasing sequence.
        for sampler in samplers:
            rank_indexes = list(sampler)
            assert rank_indexes == sorted(rank_indexes)

        indexes = list(chain(*samplers))
        assert len(indexes) == num_of_data  # no index is shared between ranks
        assert set(indexes) == set(range(num_of_data))