| @@ -107,7 +107,7 @@ class Evaluator: | |||
| ``dropout`` 与 ``batch normalization`` 将会关闭。默认为 ``True``。如果为 ``False``,``fastNLP`` 不会对 ``model`` 的 ``evaluate`` 状态做任何设置。无论 | |||
| 该值是什么,``fastNLP`` 都会在评测后将 ``model`` 的状态设置为 ``train``; | |||
| * *use_dist_sampler* -- | |||
| 是否使用分布式评测的方式。仅当 ``driver`` 为分布式类型时,该参数才有效。默认为根据 ``driver`` 是否支持 | |||
| True / False, 是否使用分布式评测的方式。仅当 ``driver`` 为分布式类型时,该参数才有效。默认为根据 ``driver`` 是否支持 | |||
| 分布式进行设置。如果为 ``True``,将使得每个进程上的 ``dataloader`` 自动使用不同数据,所有进程的数据并集是整个数据集; | |||
| * *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数; | |||
| * *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数; | |||
| @@ -290,9 +290,9 @@ class Trainer(TrainerEventTrigger): | |||
| driver 实例的 ``model_device`` 才会为 None; | |||
| 3. 对于 paddle,该参数无效; | |||
| * *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch | |||
| * *use_dist_sampler* -- True / False, 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch | |||
| 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | |||
| * *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``; | |||
| * *evaluate_use_dist_sampler* -- True / False, 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``; | |||
| 不传入该值时,该值与 ``use_dist_sampler`` 参数保持一致; | |||
| * *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||
| ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||
| @@ -565,6 +565,13 @@ class TorchDDPDriver(TorchDriver): | |||
| ) | |||
| return replace_sampler(dataloader, sampler) | |||
| else: | |||
| if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {torch.utils.data.RandomSampler, | |||
| torch.utils.data.SequentialSampler}): | |||
| raise TypeError("Using customized ``batch_sampler`` or ``sampler`` with 'DDP' may cause unseen problems, because " | |||
| "we will substitute your dataloader's sampler into our ``fastNLP.RandomSampler``. You should make " | |||
| "your customized sampler able to be used in a distributed setting before you initialize ``Trainer`` by yourself, " | |||
| "and then set the parameter ``use_dist_sampler`` of ``Trainer`` to ``False``.") | |||
| sampler = RandomSampler( | |||
| dataset=args.dataset, | |||
| shuffle=args.shuffle, | |||
| @@ -582,6 +589,7 @@ class TorchDDPDriver(TorchDriver): | |||
| if isinstance(args.sampler, ReproducibleSampler): | |||
| sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||
| elif not isinstance(args.sampler, UnrepeatedSampler): | |||
| # todo same as dist | |||
| sampler = UnrepeatedSequentialSampler( | |||
| dataset=args.dataset | |||
| ) | |||
| @@ -14,7 +14,7 @@ from fastNLP.envs import ( | |||
| FASTNLP_BACKEND_LAUNCH, | |||
| FASTNLP_GLOBAL_SEED, | |||
| ) | |||
| from fastNLP.core.samplers import re_instantiate_sampler | |||
| from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler | |||
| from fastNLP.core.utils import auto_param_call | |||
| from fastNLP.core.log import logger | |||
| @@ -23,7 +23,6 @@ if _NEED_IMPORT_TORCH: | |||
| # import torch.nn as nn | |||
| from torch.nn import Module | |||
| from torch.utils.data import DataLoader, BatchSampler | |||
| from torch.utils.data.sampler import Sampler | |||
| else: | |||
| from fastNLP.core.utils.dummy_class import DummyClass as Module | |||
| @@ -201,7 +200,10 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||
| non_default_params.add("dataset") | |||
| reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||
| reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler)) | |||
| batch_sampler = getattr(dataloader, "batch_sampler") | |||
| if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler): | |||
| raise RuntimeError("It should not be running here, please report a bug to us.") | |||
| required_args = { | |||
| p.name | |||
| @@ -243,28 +245,6 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||
| return type(dataloader)(**reconstruct_args) | |||
| def _dataloader_init_kwargs_resolve_sampler( | |||
| dataloader: "DataLoader", sampler: Optional["Sampler"] | |||
| ) -> Dict[str, Any]: | |||
| r""" | |||
| 此函数用于处理与 DataLoader 关联的采样器、batch_sampler 参数重新实例化; | |||
| """ | |||
| batch_sampler = getattr(dataloader, "batch_sampler") | |||
| # checking the batch sampler type is different than PyTorch default. | |||
| if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler): | |||
| batch_sampler = re_instantiate_sampler(batch_sampler) | |||
| return { | |||
| "sampler": None, | |||
| "shuffle": False, | |||
| "batch_sampler": batch_sampler, | |||
| "batch_size": 1, | |||
| "drop_last": False, | |||
| } | |||
| return {"sampler": sampler, "shuffle": False, "batch_sampler": None} | |||
| def replace_batch_sampler(dataloader, new_batch_sampler): | |||
| r""" | |||
| 替换一个 dataloader 的 batch_sampler; | |||