| @@ -364,16 +364,16 @@ class _MetricsWrapper: | |||||
| else: | else: | ||||
| args.append(batch) | args.append(batch) | ||||
| if not isinstance(outputs, dict): | if not isinstance(outputs, dict): | ||||
| raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" | |||||
| raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | |||||
| f" return a dict from your model or use `output_mapping` to convert it into dict type.") | f" return a dict from your model or use `output_mapping` to convert it into dict type.") | ||||
| if isinstance(metric, Metric): | if isinstance(metric, Metric): | ||||
| auto_param_call(metric.update, batch, *args) | |||||
| auto_param_call(metric.update, outputs, *args) | |||||
| elif _is_torchmetrics_metric(metric): | elif _is_torchmetrics_metric(metric): | ||||
| auto_param_call(metric.update, batch, *args) | |||||
| auto_param_call(metric.update, outputs, *args) | |||||
| elif _is_allennlp_metric(metric): | elif _is_allennlp_metric(metric): | ||||
| auto_param_call(metric.__call__, batch, *args) | |||||
| auto_param_call(metric.__call__, outputs, *args) | |||||
| elif _is_paddle_metric(metric): | elif _is_paddle_metric(metric): | ||||
| res = auto_param_call(metric.compute, batch, *args) | |||||
| res = auto_param_call(metric.compute, outputs, *args) | |||||
| metric.update(res) | metric.update(res) | ||||
| def reset(self): | def reset(self): | ||||
| @@ -105,8 +105,8 @@ class Trainer(TrainerEventTrigger): | |||||
| 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; | 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; | ||||
| 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | ||||
| :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | ||||
| 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||||
| 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; | |||||
| 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||||
| 为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `validate_step` 和 `test_step`; | |||||
| :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | ||||
| :param fp16: 是否开启混合精度训练;默认为 False; | :param fp16: 是否开启混合精度训练;默认为 False; | ||||
| :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | ||||
| @@ -325,6 +325,8 @@ class Trainer(TrainerEventTrigger): | |||||
| try: | try: | ||||
| while self.cur_epoch_idx < self.n_epochs: | while self.cur_epoch_idx < self.n_epochs: | ||||
| # 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||||
| self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||||
| self.driver.set_model_mode("train") | self.driver.set_model_mode("train") | ||||
| self.on_train_epoch_begin() | self.on_train_epoch_begin() | ||||
| self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx) | self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx) | ||||
| @@ -598,7 +600,9 @@ class Trainer(TrainerEventTrigger): | |||||
| # 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态; | # 1. callback states 和 每一个callback的具体 callback 函数的 filter 的状态; | ||||
| # 2. trainer_state; | # 2. trainer_state; | ||||
| states = {"callback_states": self.on_save_checkpoint(), | states = {"callback_states": self.on_save_checkpoint(), | ||||
| "trainer_state": self.trainer_state.state_dict()} | |||||
| "trainer_state": self.trainer_state.state_dict(), | |||||
| 'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0) | |||||
| } | |||||
| # 3. validate filter state; | # 3. validate filter state; | ||||
| if self.evaluator is not None: | if self.evaluator is not None: | ||||
| @@ -675,9 +679,13 @@ class Trainer(TrainerEventTrigger): | |||||
| # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 | ||||
| # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 | ||||
| self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | ||||
| self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ | |||||
| self.batch_idx_in_epoch | |||||
| # 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save | |||||
| self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||||
| # 5. 恢复所有 callback 的状态; | # 5. 恢复所有 callback 的状态; | ||||
| self.train_stepeckpoint(states["callback_states"]) | |||||
| self.on_load_checkpoint(states["callback_states"]) | |||||
| self.driver.barrier() | self.driver.barrier() | ||||
| @@ -60,7 +60,7 @@ class TrainerState: | |||||
| cur_epoch_idx: 当前正在运行第几个 epoch; | cur_epoch_idx: 当前正在运行第几个 epoch; | ||||
| global_forward_batches: 当前模型总共 forward 了多少个 step; | global_forward_batches: 当前模型总共 forward 了多少个 step; | ||||
| batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | ||||
| total_batches: 每一个 epoch 会 forward 多少个 step; | |||||
| num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | |||||
| total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | ||||
| """ | """ | ||||
| n_epochs: Optional[int] = None # 无论如何重新算 | n_epochs: Optional[int] = None # 无论如何重新算 | ||||
| @@ -194,9 +194,20 @@ class TorchDriver(Driver): | |||||
| sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
| else: | else: | ||||
| raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | ||||
| num_consumed_batches = states.pop('num_consumed_batches') | |||||
| if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | ||||
| states['sampler_states'] = sampler.state_dict() | |||||
| sampler_states = sampler.state_dict() | |||||
| # 如果有,需要针对 num_consumed_samples 做特殊的处理。因为DataLoader存在预取行为,直接使用sampler中的num_consumed_samples | |||||
| # 会造成多余实际消耗的问题。 | |||||
| num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) | |||||
| if num_consumed_samples_array is not None: | |||||
| if isinstance(sampler, ReproducibleSampler): # 如果是 sampler 的话,需要考虑 batch_size 。 | |||||
| try: | |||||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||||
| except: # 有可能 batch_size 为 None,就只有损失精度了 | |||||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||||
| assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." | |||||
| else: | else: | ||||
| raise RuntimeError( | raise RuntimeError( | ||||
| 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | ||||
| @@ -4,16 +4,18 @@ __all__ = [ | |||||
| ] | ] | ||||
| import math | import math | ||||
| from array import array | |||||
| from copy import deepcopy | from copy import deepcopy | ||||
| from typing import Dict, Union, List | from typing import Dict, Union, List | ||||
| from itertools import chain | from itertools import chain | ||||
| import os | |||||
| import numpy as np | import numpy as np | ||||
| from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from .utils import create_array, NumConsumedSamplesArray | |||||
| from abc import abstractmethod | from abc import abstractmethod | ||||
| from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||||
| class ReproducibleBatchSampler: | class ReproducibleBatchSampler: | ||||
| @@ -34,6 +36,13 @@ class ReproducibleBatchSampler: | |||||
| @abstractmethod | @abstractmethod | ||||
| def state_dict(self): | def state_dict(self): | ||||
| """ | |||||
| 由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomBatchSampler 中 states 里面 num_consumed_samples_array 的实现 | |||||
| 正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||||
| 了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||||
| :return: | |||||
| """ | |||||
| raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | ||||
| @abstractmethod | @abstractmethod | ||||
| @@ -67,7 +76,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| self.drop_last = drop_last | self.drop_last = drop_last | ||||
| self.data_idx = kwargs.get("data_idx", 0) | |||||
| self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) | |||||
| self.index_list = kwargs.get("index_list", self._iterate_sampler()) | self.index_list = kwargs.get("index_list", self._iterate_sampler()) | ||||
| self.need_reinitialize = kwargs.get("need_reinitialize", False) | self.need_reinitialize = kwargs.get("need_reinitialize", False) | ||||
| @@ -80,36 +89,40 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| # 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; | # 说明是在初始化时传入的是一个 sampler,理论上对应于 dataloader 在初始化时没有 batch_size,也没有 batch_sampler 的情况; | ||||
| else: | else: | ||||
| _index_lst.append(idx) | _index_lst.append(idx) | ||||
| # 64 位机器的 unsigned int 为 4 个字节,能表示的最大大小为 4294967295; | |||||
| if len(_index_lst) > 4294967295: | |||||
| # 注意 self.index_list 内存放的是全部数据的 index; | |||||
| # unsigned long | |||||
| _index_lst = array("L", _index_lst) | |||||
| else: | |||||
| # unsigned int | |||||
| _index_lst = array("I", _index_lst) | |||||
| _index_lst = create_array(len(_index_lst), _index_lst) | |||||
| return _index_lst | return _index_lst | ||||
| def __iter__(self): | def __iter__(self): | ||||
| if self.need_reinitialize: | if self.need_reinitialize: | ||||
| self.index_list = self._iterate_sampler() | self.index_list = self._iterate_sampler() | ||||
| self.data_idx = 0 | |||||
| self.num_consumed_samples = 0 | |||||
| else: | else: | ||||
| self.need_reinitialize = True | self.need_reinitialize = True | ||||
| batch = [] | batch = [] | ||||
| if self.data_idx: | |||||
| index_list = self.index_list[self.data_idx:] | |||||
| if self.num_consumed_samples: | |||||
| index_list = self.index_list[self.num_consumed_samples:] | |||||
| else: | else: | ||||
| index_list = self.index_list | index_list = self.index_list | ||||
| # 记住每个 batch 对应的 consumed_samples, 需要这个原因是由于现在的 dataloader 都存在预取数据的设计,需要再结合Trainer中 | |||||
| # batch_idx_in_epoch 才能最终确定实际消耗的数据。这个变量需要记录每次yield出去时的真实 num_consumed_samples 的数值。 | |||||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||||
| num_consumed_samples=self.num_consumed_samples) | |||||
| for idx in index_list: | for idx in index_list: | ||||
| batch.append(idx) | batch.append(idx) | ||||
| self.data_idx += 1 | |||||
| if len(batch) == self.batch_size: | if len(batch) == self.batch_size: | ||||
| self.num_consumed_samples += self.batch_size # [16, 32, 48, 64,..., ] | |||||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
| yield batch | yield batch | ||||
| batch = [] | batch = [] | ||||
| if len(batch) > 0 and not self.drop_last: | if len(batch) > 0 and not self.drop_last: | ||||
| self.num_consumed_samples += len(batch) | |||||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
| yield batch | yield batch | ||||
| # 需要重置防止边界条件问题 | |||||
| self.num_consumed_samples = 0 | |||||
| delattr(self, 'num_consumed_samples_array') | |||||
| def __len__(self) -> int: | def __len__(self) -> int: | ||||
| if self.drop_last: | if self.drop_last: | ||||
| @@ -118,7 +131,13 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| return (len(self.index_list) + self.batch_size - 1) // self.batch_size | return (len(self.index_list) + self.batch_size - 1) // self.batch_size | ||||
| def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
| return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__} | |||||
| states = { | |||||
| "index_list": deepcopy(self.index_list), | |||||
| "num_consumed_samples": self.num_consumed_samples, | |||||
| 'sampler_type': self.__class__.__name__ | |||||
| } | |||||
| states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||||
| return states | |||||
| def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
| assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | ||||
| @@ -128,7 +147,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \ | assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \ | ||||
| "record and current dataset." | "record and current dataset." | ||||
| self.index_list = _index_list | self.index_list = _index_list | ||||
| self.data_idx = states["data_idx"] | |||||
| self.num_consumed_samples = states["num_consumed_samples"] | |||||
| self.need_reinitialize = False | self.need_reinitialize = False | ||||
| def set_distributed(self, num_replicas, rank, pad=True): | def set_distributed(self, num_replicas, rank, pad=True): | ||||
| @@ -141,10 +160,10 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
| @property | @property | ||||
| def batch_idx_in_epoch(self): | def batch_idx_in_epoch(self): | ||||
| if self.drop_last: | if self.drop_last: | ||||
| return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size | |||||
| return len(self.index_list) // self.batch_size - (len(self.index_list) - self.num_consumed_samples) // self.batch_size | |||||
| else: | else: | ||||
| return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \ | return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \ | ||||
| (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | |||||
| (len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | |||||
| class BucketedBatchSampler(ReproducibleBatchSampler): | class BucketedBatchSampler(ReproducibleBatchSampler): | ||||
| @@ -180,7 +199,6 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
| self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | ||||
| self.sorted_indices = np.argsort(self.length)[::-1] # 按长度从高到低排序的 | self.sorted_indices = np.argsort(self.length)[::-1] # 按长度从高到低排序的 | ||||
| self.batch_size = batch_size | self.batch_size = batch_size | ||||
| self.num_batch_per_bucket = num_batch_per_bucket | self.num_batch_per_bucket = num_batch_per_bucket | ||||
| self.shuffle = shuffle | self.shuffle = shuffle | ||||
| @@ -212,13 +230,13 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
| self.rank = rank | self.rank = rank | ||||
| self.pad = pad | self.pad = pad | ||||
| num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||||
| else len(self.dataset) | |||||
| if self.drop_last: | |||||
| assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ | |||||
| "than the number of replicates multiplied " \ | |||||
| "with batch_size when drop_last=True." | |||||
| # num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \ | |||||
| # else len(self.dataset) | |||||
| # | |||||
| # if self.drop_last: | |||||
| # assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \ | |||||
| # "than the number of replicates multiplied " \ | |||||
| # "with batch_size when drop_last=True." | |||||
| return self | return self | ||||
| @@ -243,7 +261,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
| return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | ||||
| self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | ||||
| def __len__(self): | |||||
| def __len__(self)->int: | |||||
| """ | """ | ||||
| 返回当前 sampler 还会返回多少个 batch 的数据 | 返回当前 sampler 还会返回多少个 batch 的数据 | ||||
| @@ -309,11 +327,15 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
| if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | ||||
| batches = batches[:-1] | batches = batches[:-1] | ||||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 30), | |||||
| num_consumed_samples=self.num_consumed_samples) | |||||
| for batch in batches: | for batch in batches: | ||||
| self.num_consumed_samples += self.num_replicas * len(batch) | self.num_consumed_samples += self.num_replicas * len(batch) | ||||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
| yield list(map(int, batch)) | yield list(map(int, batch)) | ||||
| self.during_iter = False | self.during_iter = False | ||||
| self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
| delattr(self, 'num_consumed_samples_array') | |||||
| self.old_batch_size = self.batch_size | self.old_batch_size = self.batch_size | ||||
| self.old_num_batch_per_bucket = self.num_batch_per_bucket | self.old_num_batch_per_bucket = self.num_batch_per_bucket | ||||
| self.old_num_replicas = self.num_replicas | self.old_num_replicas = self.num_replicas | ||||
| @@ -376,10 +398,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
| 'num_batch_per_bucket': self.num_batch_per_bucket, | 'num_batch_per_bucket': self.num_batch_per_bucket, | ||||
| 'num_replicas': self.num_replicas | 'num_replicas': self.num_replicas | ||||
| } | } | ||||
| states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None) | |||||
| return states | return states | ||||
| def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
| # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||||
| # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||||
| assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | ||||
| "during an unfinished iteration." | "during an unfinished iteration." | ||||
| @@ -1,9 +1,14 @@ | |||||
| from typing import Dict, List, Union | from typing import Dict, List, Union | ||||
| import math | import math | ||||
| import os | |||||
| import numpy as np | import numpy as np | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
| from fastNLP.envs.env import FASTNLP_DEQUE_SIZE | |||||
| from .utils import NumConsumedSamplesArray | |||||
| __all__ = [ | __all__ = [ | ||||
| 'ReproducibleSampler', | 'ReproducibleSampler', | ||||
| @@ -30,6 +35,13 @@ class ReproducibleSampler: | |||||
| raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.") | raise NotImplementedError("Each specific sampler should implement its own `__iter__` method.") | ||||
| def state_dict(self): | def state_dict(self): | ||||
| """ | |||||
| 由于现在的DataLoader都存在预取数据的功能,因此请参考 RandomSampler 中 states 里面 num_consumed_samples_array 的实现 | |||||
| 正确设置该值。其思想是记录每个 index 对应的 num_consumed_samples ,在 Trainer.save 时会根据 Trainer 中的真实 forward | |||||
| 了多少个 sample 从 num_consumed_samples_array 取出对应的 num_consumed_samples 进行存储。 | |||||
| :return: | |||||
| """ | |||||
| raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.") | raise NotImplementedError("Each specific sampler should implement its own `state_dict` method.") | ||||
| def load_state_dict(self, states): | def load_state_dict(self, states): | ||||
| @@ -109,12 +121,15 @@ class RandomSampler(ReproducibleSampler): | |||||
| indices = indices[self.num_consumed_samples:] | indices = indices[self.num_consumed_samples:] | ||||
| indices = indices[self.rank:len(indices):self.num_replicas] | indices = indices[self.rank:len(indices):self.num_replicas] | ||||
| assert len(indices) == self.num_left_samples | assert len(indices) == self.num_left_samples | ||||
| for index in indices: | |||||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||||
| num_consumed_samples=self.num_consumed_samples) | |||||
| for idx, index in enumerate(indices, start=1): | |||||
| self.num_consumed_samples += self.num_replicas | self.num_consumed_samples += self.num_replicas | ||||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
| yield index | yield index | ||||
| self.during_iter = False | self.during_iter = False | ||||
| self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
| delattr(self, 'num_consumed_samples_array') | |||||
| def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
| """ | """ | ||||
| @@ -134,18 +149,13 @@ class RandomSampler(ReproducibleSampler): | |||||
| return indices | return indices | ||||
| def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
| states = { | |||||
| 'seed': self.seed, | |||||
| 'epoch': self.epoch, | |||||
| 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||||
| 'sampler_type': self.__class__.__name__, | |||||
| 'length': len(self.dataset), | |||||
| 'shuffle': self.shuffle | |||||
| } | |||||
| states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||||
| 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||||
| 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||||
| return states | return states | ||||
| def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
| # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||||
| # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||||
| assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | ||||
| "during an unfinished iteration." | "during an unfinished iteration." | ||||
| @@ -158,7 +168,7 @@ class RandomSampler(ReproducibleSampler): | |||||
| self.seed = states['seed'] | self.seed = states['seed'] | ||||
| self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
| self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
| if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||||
| if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||||
| self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
| if self.shuffle != states['shuffle']: | if self.shuffle != states['shuffle']: | ||||
| logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | ||||
| @@ -245,11 +255,15 @@ class SequentialSampler(RandomSampler): | |||||
| indices = indices[self.rank:len(indices):self.num_replicas] | indices = indices[self.rank:len(indices):self.num_replicas] | ||||
| assert len(indices) == self.num_left_samples | assert len(indices) == self.num_left_samples | ||||
| for index in indices: | |||||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||||
| num_consumed_samples=self.num_consumed_samples) | |||||
| for idx, index in enumerate(indices, start=1): | |||||
| self.num_consumed_samples += self.num_replicas | self.num_consumed_samples += self.num_replicas | ||||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
| yield index | yield index | ||||
| self.during_iter = False | self.during_iter = False | ||||
| self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
| delattr(self, 'num_consumed_samples_array') | |||||
| def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
| """ | """ | ||||
| @@ -260,15 +274,13 @@ class SequentialSampler(RandomSampler): | |||||
| return list(range(len(self.dataset))) | return list(range(len(self.dataset))) | ||||
| def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
| states = { | |||||
| 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||||
| 'sampler_type': self.__class__.__name__, | |||||
| 'length': len(self.dataset), | |||||
| } | |||||
| states = {'num_consumed_samples': self.num_consumed_samples, 'sampler_type': self.__class__.__name__, | |||||
| 'length': len(self.dataset), | |||||
| 'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)} | |||||
| return states | return states | ||||
| def load_state_dict(self, states: Dict): | def load_state_dict(self, states: Dict): | ||||
| # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||||
| # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; | |||||
| assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | ||||
| "during an unfinished iteration." | "during an unfinished iteration." | ||||
| @@ -334,9 +346,13 @@ class SortedSampler(SequentialSampler): | |||||
| indices = indices[self.rank:len(indices):self.num_replicas] | indices = indices[self.rank:len(indices):self.num_replicas] | ||||
| assert len(indices) == self.num_left_samples | assert len(indices) == self.num_left_samples | ||||
| for index in indices: | |||||
| self.num_consumed_samples_array = NumConsumedSamplesArray(buffer_size=os.environ.get(FASTNLP_DEQUE_SIZE, 2000), | |||||
| num_consumed_samples=self.num_consumed_samples) | |||||
| for idx, index in enumerate(indices, start=1): | |||||
| self.num_consumed_samples += self.num_replicas | self.num_consumed_samples += self.num_replicas | ||||
| self.num_consumed_samples_array.push(self.num_consumed_samples) | |||||
| yield index | yield index | ||||
| self.during_iter = False | self.during_iter = False | ||||
| self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
| delattr(self, 'num_consumed_samples_array') | |||||
| @@ -2,6 +2,9 @@ __all__ = [ | |||||
| 're_instantiate_sampler', | 're_instantiate_sampler', | ||||
| 'conversion_between_reproducible_and_unrepeated_sampler' | 'conversion_between_reproducible_and_unrepeated_sampler' | ||||
| ] | ] | ||||
| from array import array | |||||
| from typing import Sequence | |||||
| from collections import deque | |||||
| from fastNLP.core.samplers.unrepeated_sampler import * | from fastNLP.core.samplers.unrepeated_sampler import * | ||||
| from fastNLP.core.samplers.reproducible_sampler import * | from fastNLP.core.samplers.reproducible_sampler import * | ||||
| @@ -39,4 +42,56 @@ def re_instantiate_sampler(sampler, new_sampler_class=None): | |||||
| all_attributes = vars(sampler) | all_attributes = vars(sampler) | ||||
| if new_sampler_class is not None: | if new_sampler_class is not None: | ||||
| return new_sampler_class(**all_attributes) | return new_sampler_class(**all_attributes) | ||||
| return type(sampler)(**all_attributes) | |||||
| return type(sampler)(**all_attributes) | |||||
| def create_array(length, fill_value) -> array: | |||||
| """ | |||||
| 根据长度自动创建 array ,超过 4294967295 需要使用 'L', 否则使用 'I' | |||||
| :param length: | |||||
| :param fill_value: | |||||
| :return: | |||||
| """ | |||||
| if not isinstance(fill_value, Sequence): | |||||
| fill_value = [fill_value]*length | |||||
| if length > 4294967295: | |||||
| _index_lst = array("L", fill_value) | |||||
| else: | |||||
| _index_lst = array("I", fill_value) | |||||
| return _index_lst | |||||
| class NumConsumedSamplesArray: | |||||
| def __init__(self, buffer_size=2000, num_consumed_samples=0): | |||||
| """ | |||||
| 保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少 | |||||
| ex: | |||||
| array = NumConsumedSamplesArray(buffer_size=3) | |||||
| for i in range(10): | |||||
| array.push(i) | |||||
| array[9] # 输出为9,表示这个位置真实的 num_consumed_samples 是多少。 | |||||
| array[6] # 报错,因为只保留了3个最近的数据,6超过了最大buffer的记录了,即 [7, 8, 9] | |||||
| :param buffer_size: 报错多少个历史。 | |||||
| :param num_consumed_samples: 第一个 num_consumed_samples 是多少。 | |||||
| """ | |||||
| self.count = 0 | |||||
| self.deque = deque(maxlen=buffer_size) | |||||
| if num_consumed_samples is not None: | |||||
| self.push(num_consumed_samples) | |||||
| self.buffer_size = buffer_size | |||||
| def __getitem__(self, item): | |||||
| if len(self.deque) == 0: # 如果没有任何缓存的内容,说明还没有写入,直接返回0 | |||||
| return 0 | |||||
| assert isinstance(item, int), "Only int index allowed." | |||||
| assert self.count-len(self.deque)<=item<self.count, f"Only keep {len(self.deque)} history index." | |||||
| index = len(self.deque) - (self.count - item) | |||||
| return self.deque[index] | |||||
| def push(self, num_consumed_samples): | |||||
| self.deque.append(num_consumed_samples) | |||||
| self.count += 1 | |||||
| @@ -45,6 +45,8 @@ FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | |||||
| # todo 注释 | # todo 注释 | ||||
| FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | ||||
| # fastNLP 中初始化deque的默认大小 | |||||
| FASTNLP_DEQUE_SIZE = 'FASTNLP_DEQUE_SIZE' | |||||
| # todo 注释 直接使用的变量 | # todo 注释 直接使用的变量 | ||||
| FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | ||||
| @@ -3,6 +3,7 @@ from array import array | |||||
| import numpy as np | import numpy as np | ||||
| import pytest | import pytest | ||||
| from itertools import chain | from itertools import chain | ||||
| from copy import deepcopy | |||||
| from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | ||||
| from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | ||||
| @@ -30,7 +31,7 @@ class TestReproducibleBatchSampler: | |||||
| _get_re_batchsampler = dataloader.batch_sampler | _get_re_batchsampler = dataloader.batch_sampler | ||||
| assert isinstance(_get_re_batchsampler, RandomBatchSampler) | assert isinstance(_get_re_batchsampler, RandomBatchSampler) | ||||
| state = _get_re_batchsampler.state_dict() | state = _get_re_batchsampler.state_dict() | ||||
| assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||||
| assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, | |||||
| "sampler_type": "RandomBatchSampler"} | "sampler_type": "RandomBatchSampler"} | ||||
| # 2. 断点重训,重新生成一个 dataloader; | # 2. 断点重训,重新生成一个 dataloader; | ||||
| @@ -413,26 +414,102 @@ class TestBucketedBatchSampler: | |||||
| @pytest.mark.parametrize('drop_last', [True, False]) | @pytest.mark.parametrize('drop_last', [True, False]) | ||||
| @pytest.mark.parametrize('pad', [True, False]) | @pytest.mark.parametrize('pad', [True, False]) | ||||
| @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | ||||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||||
| def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica): | |||||
| # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2): | |||||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
| def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||||
| # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2): | |||||
| dataset = DatasetWithVaryLength(num_of_data=num_samples) | dataset = DatasetWithVaryLength(num_of_data=num_samples) | ||||
| batch_size = 6 | batch_size = 6 | ||||
| if num_replica*batch_size > num_samples: | |||||
| if num_replicas*batch_size > num_samples: | |||||
| return | return | ||||
| num_batch_per_bucket = 10 | num_batch_per_bucket = 10 | ||||
| samplers = [] | samplers = [] | ||||
| lengths = [] | lengths = [] | ||||
| for i in range(num_replica): | |||||
| for i in range(num_replicas): | |||||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | ||||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | ||||
| sampler.set_distributed(num_replica, rank=i, pad=pad) | |||||
| sampler.set_distributed(num_replicas, rank=i, pad=pad) | |||||
| sampler.set_epoch(0) | sampler.set_epoch(0) | ||||
| samplers.append(sampler) | samplers.append(sampler) | ||||
| lengths.append(len(list(iter(sampler)))) | lengths.append(len(list(iter(sampler)))) | ||||
| assert len(set(lengths))==1 | assert len(set(lengths))==1 | ||||
| bucket_diff = batch_size * num_batch_per_bucket * num_replica | |||||
| bucket_diff = batch_size * num_batch_per_bucket * num_replicas | |||||
| for bs in zip(*samplers): | for bs in zip(*samplers): | ||||
| diff = max(chain(*bs)) - min(chain(*bs)) | diff = max(chain(*bs)) - min(chain(*bs)) | ||||
| assert diff <= bucket_diff | assert diff <= bucket_diff | ||||
| @pytest.mark.parametrize('shuffle', [True, False]) | |||||
| @pytest.mark.parametrize('drop_last', [True, False]) | |||||
| @pytest.mark.parametrize('pad', [True, False]) | |||||
| @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||||
| @pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||||
| def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||||
| """ | |||||
| 测试是否能够正确地恢复使用过的(forward)数据,由于 DataLoader 存在预取,所以 Sampler 自身的 num_consumed_samples 可能 | |||||
| 偏多 | |||||
| :return: | |||||
| """ | |||||
| batch_size = 6 | |||||
| num_batch_per_bucket = 10 | |||||
| dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
| samplers = [] | |||||
| for i in range(num_replicas): | |||||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||||
| sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||||
| samplers.append(sampler) | |||||
| count = 0 | |||||
| already_seen_sets = [set()] | |||||
| already_seen_set = set() | |||||
| for batchs in zip(*samplers): | |||||
| batch = chain(*batchs) | |||||
| already_seen_set.update(batch) | |||||
| already_seen_sets.append(deepcopy(already_seen_set)) | |||||
| count += 1 | |||||
| if count > 3: | |||||
| break | |||||
| states = samplers[0].state_dict() | |||||
| for i in range(len(already_seen_sets)): | |||||
| if states['num_consumed_samples_array'] is not None: | |||||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1, | |||||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||||
| drop_last=drop_last) | |||||
| sampler.set_epoch(0) | |||||
| already_seen_set = deepcopy(already_seen_sets[i]) | |||||
| for batch in sampler: | |||||
| already_seen_set.update(batch) | |||||
| assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len( | |||||
| dataset) | |||||
| # 测试保存之后再次保存 | |||||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1, | |||||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||||
| drop_last=drop_last) | |||||
| sampler.set_epoch(0) | |||||
| if states['num_consumed_samples_array'] is not None: | |||||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||||
| if len(already_seen_sets)<3: | |||||
| return | |||||
| already_seen_set = already_seen_sets[2] | |||||
| count = 0 | |||||
| for batch in sampler: | |||||
| already_seen_set.update(batch) | |||||
| count += 1 | |||||
| if count > 6: | |||||
| break | |||||
| states = sampler.state_dict() | |||||
| if states['num_consumed_samples_array'] is not None: | |||||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||||
| sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, | |||||
| num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, | |||||
| drop_last=drop_last) | |||||
| sampler.load_state_dict(states) | |||||
| sampler.set_epoch(0) | |||||
| for batch in sampler: | |||||
| already_seen_set.update(batch) | |||||
| assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) | |||||
| @@ -3,6 +3,7 @@ import pytest | |||||
| from functools import partial | from functools import partial | ||||
| from itertools import chain | from itertools import chain | ||||
| from copy import deepcopy | |||||
| from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | ||||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
| @@ -180,6 +181,63 @@ class TestRandomSamplerYh: | |||||
| assert seen <= 1 if pad else seen == 0 | assert seen <= 1 if pad else seen == 0 | ||||
| assert seen_in_other_rank<=1 # 因为pad可能重复 | assert seen_in_other_rank<=1 # 因为pad可能重复 | ||||
| @pytest.mark.parametrize('shuffle', [True, False]) | |||||
| @pytest.mark.parametrize('pad', [True, False]) | |||||
| @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||||
| @pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||||
| def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas): | |||||
| # 测试在 sampler 多生成的时候,可以仍然可以恢复 | |||||
| dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
| samplers = [] | |||||
| for i in range(num_replicas): | |||||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
| sampler.set_epoch(0) | |||||
| sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||||
| samplers.append(sampler) | |||||
| count = 0 | |||||
| already_seen_sets = [set()] | |||||
| already_seen_set = set() | |||||
| for idxes in zip(*samplers): | |||||
| already_seen_set.update(idxes) | |||||
| already_seen_sets.append(deepcopy(already_seen_set)) | |||||
| count += 1 | |||||
| if count > 3: | |||||
| break | |||||
| states = samplers[0].state_dict() | |||||
| for i in range(len(already_seen_sets)): | |||||
| if states['num_consumed_samples_array'] is not None: | |||||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][i] | |||||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
| already_seen_set = deepcopy(already_seen_sets[i]) | |||||
| for batch in sampler: | |||||
| already_seen_set.add(batch) | |||||
| assert len(already_seen_set) == len(dataset) | |||||
| # 测试保存之后再次保存 | |||||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
| sampler.set_epoch(0) | |||||
| if states['num_consumed_samples_array'] is not None: | |||||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][2] | |||||
| if len(already_seen_sets)<3: | |||||
| return | |||||
| already_seen_set = already_seen_sets[2] | |||||
| count = 0 | |||||
| for idx in sampler: | |||||
| already_seen_set.add(idx) | |||||
| count += 1 | |||||
| if count > 6: | |||||
| break | |||||
| states = sampler.state_dict() | |||||
| if states['num_consumed_samples_array'] is not None: | |||||
| states['num_consumed_samples'] = states['num_consumed_samples_array'][count] | |||||
| sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
| sampler.load_state_dict(states) | |||||
| sampler.set_epoch(0) | |||||
| for idx in sampler: | |||||
| already_seen_set.add(idx) | |||||
| assert len(already_seen_set)==len(dataset) | |||||
| class TestRandomSampler: | class TestRandomSampler: | ||||
| # 测试单卡; | # 测试单卡; | ||||
| @@ -386,7 +444,7 @@ class TestSortedSampler: | |||||
| assert indexes==list(range(num_of_data-1, -1, -1)) | assert indexes==list(range(num_of_data-1, -1, -1)) | ||||
| @pytest.mark.parametrize('pad', [True, False]) | @pytest.mark.parametrize('pad', [True, False]) | ||||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
| def test_multi(self, pad, num_replica, num_of_data): | def test_multi(self, pad, num_replica, num_of_data): | ||||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
| @@ -540,7 +598,7 @@ class TestSequentialSampler: | |||||
| assert indexes==list(range(num_of_data)) | assert indexes==list(range(num_of_data)) | ||||
| @pytest.mark.parametrize('pad', [True, False]) | @pytest.mark.parametrize('pad', [True, False]) | ||||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
| def test_multi(self, pad, num_replica, num_of_data): | def test_multi(self, pad, num_replica, num_of_data): | ||||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
| @@ -25,7 +25,7 @@ class TestUnrepeatedSampler: | |||||
| indexes = set(sampler) | indexes = set(sampler) | ||||
| assert indexes==set(range(num_of_data)) | assert indexes==set(range(num_of_data)) | ||||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
| @pytest.mark.parametrize('shuffle', [False, True]) | @pytest.mark.parametrize('shuffle', [False, True]) | ||||
| def test_multi(self, num_replica, num_of_data, shuffle): | def test_multi(self, num_replica, num_of_data, shuffle): | ||||
| @@ -50,7 +50,7 @@ class TestUnrepeatedSortedSampler: | |||||
| indexes = list(sampler) | indexes = list(sampler) | ||||
| assert indexes==list(range(num_of_data-1, -1, -1)) | assert indexes==list(range(num_of_data-1, -1, -1)) | ||||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
| def test_multi(self, num_replica, num_of_data): | def test_multi(self, num_replica, num_of_data): | ||||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
| @@ -81,7 +81,7 @@ class TestUnrepeatedSequentialSampler: | |||||
| indexes = list(sampler) | indexes = list(sampler) | ||||
| assert indexes==list(range(num_of_data)) | assert indexes==list(range(num_of_data)) | ||||
| @pytest.mark.parametrize('num_replica', [2, 3]) | |||||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
| def test_multi(self, num_replica, num_of_data): | def test_multi(self, num_replica, num_of_data): | ||||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||