| @@ -397,12 +397,13 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: | |||
| """ | |||
| # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 | |||
| # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) | |||
| if device is None: | |||
| device = torch.cuda.current_device() | |||
| if _TORCH_GREATER_EQUAL_1_8: | |||
| objs = [None for _ in range(dist.get_world_size(group))] | |||
| dist.all_gather_object(objs, obj) | |||
| apply_to_collection(obj, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 | |||
| return objs | |||
| if device is None: | |||
| device = torch.cuda.current_device() | |||
| group = group if group is not None else torch.distributed.group.WORLD | |||
| data = convert_to_tensors(obj, device=device) | |||
| data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) | |||
| @@ -130,8 +130,8 @@ class TorchSingleDriver(TorchDriver): | |||
| else: | |||
| return self._test_step(batch) | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||
| reproducible: bool = False, sampler_or_batch_sampler=None): | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, | |||
| reproducible: bool = False): | |||
| if isinstance(dist, ReproducibleBatchSampler): | |||
| return replace_batch_sampler(dataloader, dist) | |||
| elif isinstance(dist, ReproducibleIterator): | |||
| @@ -34,6 +34,7 @@ class TorchBackend(Backend): | |||
| if method is None: | |||
| raise AggregateMethodError(should_have_aggregate_method=True) | |||
| tensor = self._gather_all(tensor) | |||
| # tensor = self.all_gather_object(tensor) | |||
| if isinstance(tensor[0], torch.Tensor): | |||
| tensor = torch.stack(tensor) | |||
| # 第一步, aggregate结果 | |||
| @@ -34,6 +34,7 @@ class Element: | |||
| 自动aggregate对应的元素 | |||
| """ | |||
| self._check_value_initialized() | |||
| try: | |||
| self._value = self.backend.aggregate(self._value, self.aggregate_method) | |||
| except AggregateMethodError as e: | |||
| @@ -216,9 +216,9 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
| class SpanFPreRecMetric(Metric): | |||
| def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None, | |||
| encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro', | |||
| beta=1, aggregate_when_get_metric: bool = True,) -> None: | |||
| def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | |||
| only_gross: bool = True, f_type='micro', | |||
| beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: | |||
| super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
| if f_type not in ('micro', 'macro'): | |||
| raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
| @@ -249,9 +249,18 @@ class SpanFPreRecMetric(Metric): | |||
| self.only_gross = only_gross | |||
| self.tag_vocab = tag_vocab | |||
| self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||
| self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||
| self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||
| self._true_positives = {} | |||
| self._false_positives = {} | |||
| self._false_negatives = {} | |||
| for word, _ in tag_vocab: | |||
| word = word.lower() | |||
| if word != 'o': | |||
| word = word.split('-')[1] | |||
| if word in self._true_positives: | |||
| continue | |||
| self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) | |||
| self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend) | |||
| self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend) | |||
| def get_metric(self) -> dict: | |||
| evaluate_result = {} | |||
| @@ -284,10 +293,17 @@ class SpanFPreRecMetric(Metric): | |||
| evaluate_result['rec'] = rec_sum / len(tags) | |||
| if self.f_type == 'micro': | |||
| tp, fn, fp = [], [], [] | |||
| for val in self._true_positives.values(): | |||
| tp.append(val.get_scalar()) | |||
| for val in self._false_negatives.values(): | |||
| fn.append(val.get_scalar()) | |||
| for val in self._false_positives.values(): | |||
| fp.append(val.get_scalar()) | |||
| f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||
| sum(val.get_scalar() for val in self._true_positives.values()), | |||
| sum(val.get_scalar() for val in self._false_negatives.values()), | |||
| sum(val.get_scalar() for val in self._false_positives.values())) | |||
| sum(tp), | |||
| sum(fn), | |||
| sum(fp)) | |||
| evaluate_result['f'] = f | |||
| evaluate_result['pre'] = pre | |||
| evaluate_result['rec'] = rec | |||
| @@ -11,11 +11,11 @@ __all__ = [ | |||
| 'PollingSampler', | |||
| 'ReproducibleIterator', | |||
| 'RandomSampler', | |||
| 'ReproducibleBatchSampler', | |||
| 're_instantiate_sampler' | |||
| ] | |||
| from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler | |||
| from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||
| from .reproducible_sampler import ReproducibleIterator, RandomSampler, ReproducibleBatchSampler, re_instantiate_sampler | |||
| from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler | |||
| from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler | |||
| @@ -0,0 +1,397 @@ | |||
| __all__ = [ | |||
| 'BucketedBatchSampler', | |||
| "ReproducibleBatchSampler" | |||
| ] | |||
| import math | |||
| from array import array | |||
| from copy import deepcopy | |||
| from typing import Dict, Union, List | |||
| from itertools import chain | |||
| import numpy as np | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.log import logger | |||
| from abc import abstractmethod | |||
class ReproducibleBatchIterator:
    """Abstract interface for batch samplers whose iteration state can be saved and restored.

    Concrete subclasses must implement every method below so that training can be
    checkpointed mid-epoch and resumed deterministically.
    """

    @abstractmethod
    def set_distributed(self, num_replicas, rank, pad=True):
        """Switch the sampler into distributed mode (``num_replicas`` workers, this worker is ``rank``)."""
        raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")

    @abstractmethod
    def __len__(self):
        """Return how many batches this sampler will still yield."""
        raise NotImplementedError("Each specific batch_sampler should implement its own `__len__` method.")

    @abstractmethod
    def __iter__(self):
        """Yield batches as lists of sample indices."""
        raise NotImplementedError("Each specific batch_sampler should implement its own `__iter__` method.")

    @abstractmethod
    def state_dict(self):
        """Return a serializable snapshot of the iteration state."""
        raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.")

    @abstractmethod
    def load_state_dict(self, states):
        """Restore the iteration state produced by :meth:`state_dict`."""
        raise NotImplementedError("Each specific batch_sampler should implement its own `load_state_dict` method.")

    @abstractmethod
    def set_epoch(self, epoch):
        """Inform the sampler of the current epoch (used for reshuffling seeds); optional."""
        pass


class ReproducibleBatchSampler(ReproducibleBatchIterator):
    # batch_sampler / batch_size / drop_last are expected to be obtained via the
    # driver's get_dataloader_args function;
    def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
        """
        Wrapper that makes the iteration state of ``batch_sampler`` recoverable.

        :param batch_sampler: an iterable yielding ints or lists of ints.
            It is fully iterated once up front and the flattened sample indices are
            cached; batches of ``batch_size`` indices are then served from that cache.
        :param batch_size: number of samples per batch.
        :param drop_last: whether to drop the final batch when it holds fewer than
            ``batch_size`` samples.
        :param kwargs: reserved for fastNLP internal use (state restoration).
        """
        self.batch_sampler = batch_sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        # position (within index_list) of the next index to serve
        self.data_idx = kwargs.get("data_idx", 0)
        self.index_list = kwargs.get("index_list", self._iterate_sampler())
        # False only for the first epoch (the cached index list is fresh);
        # afterwards __iter__ must re-walk batch_sampler to pick up a new ordering
        self.need_reinitialize = kwargs.get("need_reinitialize", False)

    def _iterate_sampler(self):
        """Flatten ``batch_sampler`` into a compact ``array`` of sample indices."""
        _index_lst = []
        for idx in self.batch_sampler:
            if isinstance(idx, list):
                _index_lst.extend(idx)
            # a bare int means a plain (non-batch) sampler was passed in, i.e. the
            # dataloader was created without batch_size and without batch_sampler;
            else:
                _index_lst.append(idx)
        # 'I' (unsigned int) is at least 4 bytes and can represent up to 4294967295.
        # For anything larger fall back to 'Q' (unsigned long long, >= 8 bytes on
        # every platform). Note: 'L' must NOT be used here — on LLP64 platforms
        # (e.g. Windows) 'L' is also only 4 bytes and would overflow silently.
        if len(_index_lst) > 4294967295:
            _index_lst = array("Q", _index_lst)
        else:
            _index_lst = array("I", _index_lst)
        return _index_lst

    def __iter__(self):
        if self.need_reinitialize:
            # not the first epoch: re-walk the wrapped sampler for a fresh ordering
            self.index_list = self._iterate_sampler()
            self.data_idx = 0
        else:
            self.need_reinitialize = True
        batch = []
        if self.data_idx:
            # resuming mid-epoch: skip the indices already consumed
            index_list = self.index_list[self.data_idx:]
        else:
            index_list = self.index_list
        for idx in index_list:
            batch.append(idx)
            self.data_idx += 1
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.index_list) // self.batch_size
        else:
            return (len(self.index_list) + self.batch_size - 1) // self.batch_size

    def state_dict(self) -> Dict:
        return {"index_list": deepcopy(self.index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__}

    def load_state_dict(self, states: Dict):
        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."
        _index_list = states["index_list"]
        assert len(_index_list) == len(self.index_list), "The number of samples is different between the checkpoint " \
                                                         "record and current dataset."
        self.index_list = _index_list
        self.data_idx = states["data_idx"]
        # the restored index list must be served as-is, not regenerated
        self.need_reinitialize = False

    def set_distributed(self, num_replicas, rank, pad=True):
        # fixed: message previously misnamed the class and used a pointless f-string
        raise RuntimeError("ReproducibleBatchSampler does not support to change to distributed training.")

    def set_epoch(self, epoch):
        # forward the epoch to the wrapped sampler when it supports reshuffling
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
            self.batch_sampler.sampler.set_epoch(epoch)

    @property
    def batch_idx_in_epoch(self):
        """Number of batches already yielded in the current epoch."""
        if self.drop_last:
            return len(self.index_list) // self.batch_size - (len(self.index_list) - self.data_idx) // self.batch_size
        else:
            return (len(self.index_list) + self.batch_size - 1) // self.batch_size - \
                   (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size
class BucketedBatchSampler(ReproducibleBatchIterator):
    def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
                 shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
        """
        Sort samples by length first, then group them into buckets of
        ``batch_size * num_batch_per_bucket`` samples; samples are only combined into
        batches within one bucket, so each batch needs little padding (samples in a
        bucket have similar lengths).

        :param dataset: any container implementing ``__len__``.
        :param length: if a List, must be as long as ``dataset`` and give each sample's
            length; a str is only supported when ``dataset`` is a fastNLP DataSet —
            it is taken as a field name, and if the field's elements are ints those
            values are the lengths, otherwise ``len()`` of each element is used.
        :param batch_size: number of samples per batch.
        :param num_batch_per_bucket: how many batches form one bucket; shuffling only
            happens inside a bucket.
        :param shuffle: whether to shuffle; when False the data comes out roughly
            sorted from longest to shortest.
        :param drop_last: whether to drop the final batch if it cannot be filled up
            to ``batch_size``.
        :param seed: random seed.
        :param kwargs: reserved for fastNLP internal use (state restoration).
        """
        super().__init__()
        if isinstance(dataset, DataSet):
            length = dataset.get_field(length)
            if not isinstance(length[0], int):
                # field elements are sequences — use their lengths
                length = list(map(len, length))
        else:
            assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
                                                "the length parameter can only be List[int]"
        assert len(length) == len(dataset), "The length of `data` and `length` should be equal."
        self.dataset = dataset
        self.length = np.array(length, dtype=int)  # per-sample lengths
        self.sorted_indices = np.argsort(self.length)[::-1]  # indices sorted from longest to shortest
        self.batch_size = batch_size
        self.num_batch_per_bucket = num_batch_per_bucket
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed
        # total samples consumed so far, counting the output of all other ranks too
        self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)
        # distributed-training parameters
        self.num_replicas = kwargs.get("num_replicas", 1)
        self.rank = kwargs.get("rank", 0)
        self.epoch = kwargs.get("epoch", -1)
        self.pad = kwargs.get("pad", False)  # meaningless on a single card
        # True while inside an unfinished iteration; set_distributed() and
        # load_state_dict() must not be called then
        self.during_iter = kwargs.get("during_iter", False)
        # internal variables used for state restoration
        self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)
        self.old_num_batch_per_bucket = kwargs.get('old_num_batch_per_bucket', self.num_batch_per_bucket)
        # NOTE(review): old_num_replicas is NOT initialized here; the resume branch of
        # __iter__ reads it, so resuming with num_consumed_samples > 0 without a prior
        # load_state_dict() would raise AttributeError — confirm.

    def set_distributed(self, num_replicas, rank, pad=True):
        """Switch to distributed mode; must be called before iteration starts."""
        assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
                                          "during an unfinished iteration."
        assert num_replicas > 0 and isinstance(num_replicas, int)
        assert isinstance(rank, int) and 0 <= rank < num_replicas
        # when this is called, all state is assumed to be at the start of an epoch;
        self.num_replicas = num_replicas
        self.rank = rank
        self.pad = pad
        num_samples = (len(self.dataset)+self.num_replicas-1)//self.num_replicas*self.num_replicas if pad \
            else len(self.dataset)
        if self.drop_last:
            assert self.num_replicas*self.batch_size<=num_samples, "The number of samples should be greater " \
                                                                   "than the number of replicates multiplied " \
                                                                   "with batch_size when drop_last=True."
        return self

    @property
    def total_size(self):
        """
        Total number of indices this sampler will eventually have produced, across all
        ranks. Because of replicas and padding this can be equal to, greater than, or
        smaller than ``len(dataset)``.

        :return:
        """
        return self.num_consumed_samples + self.num_replicas*self.num_left_samples

    @property
    def num_left_samples(self):
        """
        How many samples remain in the current iteration, counted for this rank only.

        :return:
        """
        num_consumed_samples = self.num_consumed_samples
        return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
            self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))

    def __len__(self):
        """
        How many batches this sampler will still return.

        :return:
        """
        num_sampler_per_rank = self.total_size//self.num_replicas
        num_batches = num_sampler_per_rank//self.batch_size if self.drop_last else \
            (num_sampler_per_rank+self.batch_size-1)//self.batch_size
        return num_batches

    def __iter__(self):
        if self.during_iter:  # a previous iteration never finished; force a fresh start
            self.num_consumed_samples = 0
        self.during_iter = True
        sorted_indices = deepcopy(self.sorted_indices).tolist()  # longest-to-shortest order
        if self.shuffle:
            if self.num_consumed_samples > 0:
                # resuming mid-epoch: first reconstruct the previous ordering (using the
                # old batch/bucket/replica parameters) and drop the consumed prefix
                _batches = []
                for _i in range(self.old_num_replicas):
                    _sorted_indices = sorted_indices[_i:len(sorted_indices):self.old_num_replicas]
                    __batches = self.bucketerize(_sorted_indices, self.old_batch_size, self.old_num_batch_per_bucket,
                                                 seed=self.seed+self.epoch)
                    _batches.append(__batches)
                # interleave the per-rank batch streams to recover global consumption order
                batches = list(chain(*[_ for _ in zip(*_batches)]))
                sorted_indices = list(chain(*batches))
                sorted_indices = sorted_indices[self.num_consumed_samples:]
                # re-sort the remainder by length (longest first)
                sub_length = self.length[sorted_indices]
                sorted_indices = np.array(sorted_indices)[np.argsort(sub_length)[::-1]]
            # take this rank's strided slice,
            sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas]
            batches = self.bucketerize(sorted_indices, self.batch_size, self.num_batch_per_bucket,
                                       seed=self.seed+self.epoch)
            batches = list(map(list, batches))
        else:
            sorted_indices = sorted_indices[self.num_consumed_samples:]
            sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas]
            _num_batches = len(sorted_indices) // self.batch_size
            if _num_batches == 0:
                batches = [sorted_indices]
            else:
                batches = list(map(list, np.array_split(sorted_indices[:_num_batches*self.batch_size], _num_batches)))
                if len(sorted_indices)%self.batch_size!=0:
                    batches.append(sorted_indices[_num_batches*self.batch_size:])
        # balance the ranks: either pad the short ranks or truncate the long ones
        need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
        if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
            if len(batches) > 0:
                if len(batches[-1])<self.batch_size:
                    batches[-1].append(batches[-1][0])  # repeat a sample from the same bucket, keeping bucket integrity
                else:
                    batches.append([batches[-1][0]])
        elif self.pad is False and need_pad_num !=0 and need_pad_num>self.rank:
            if len(batches):
                batches[-1].pop(-1)
                if len(batches[-1])==0:
                    batches.pop(-1)
        assert len(list(chain(*batches))) == self.num_left_samples
        if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
            batches = batches[:-1]
        for batch in batches:
            # count consumption globally (every rank advances in lockstep)
            self.num_consumed_samples += self.num_replicas * len(batch)
            yield list(map(int, batch))
        self.during_iter = False
        self.num_consumed_samples = 0
        self.old_batch_size = self.batch_size
        self.old_num_batch_per_bucket = self.num_batch_per_bucket
        self.old_num_replicas = self.num_replicas
        if self.epoch < 0:  # guard against users never calling set_epoch: vary the seed offset per epoch
            self.epoch -= 1

    def bucketerize(self, sorted_indices, batch_size, num_batch_per_bucket, seed):
        """
        Split indices into buckets and shuffle within each bucket.

        :param sorted_indices: List[int]
        :param batch_size: int
        :param num_batch_per_bucket: int
        :param seed: int
        :return: List[List[int]]
        """
        # effective bucket size
        bucket_size = min(len(sorted_indices), batch_size * num_batch_per_bucket)
        rng = np.random.default_rng(abs(seed))
        num_buckets = (len(sorted_indices) + bucket_size - 1) // bucket_size
        batches = []
        batch_indices = []
        for i in range(num_buckets):
            bucket = sorted_indices[i * bucket_size:(i + 1) * bucket_size]
            rng.shuffle(bucket)  # shuffle within the bucket
            _num_batches = len(bucket) // batch_size
            if _num_batches == 0:
                _batches = [bucket]
            else:
                _batches = np.array_split(bucket[:_num_batches*batch_size], _num_batches)
                if len(bucket) % batch_size != 0:
                    _batches.append(bucket[_num_batches*batch_size:])
            batch_indices.extend(list(range(len(batches), len(batches) + len(_batches))))
            batches.extend(_batches)
        last_batches = []
        # The final batch never takes part in the shuffle: on some ranks it may hold
        # fewer than batch_size samples and must stay at the end, so every rank keeps
        # its last batch fixed for consistency.
        if len(batches) >= 1:
            last_batches = [list(batches[-1])]
            batch_indices = list(batch_indices[:-1])
        rng = np.random.default_rng(abs(seed))  # fresh rng so differing bucket sizes don't desync the random state across ranks
        rng.shuffle(batch_indices)  # also shuffle batch order; keeps per-batch lengths similar across cards
        batches = (np.array(batches)[batch_indices]).tolist()
        if last_batches:
            batches = batches + last_batches
        return batches

    def state_dict(self) -> Dict:
        if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket:
            raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
                               " consumed. ")
        states = {
            'seed': self.seed,
            'epoch': self.epoch,
            'num_consumed_samples': self.num_consumed_samples,  # counts data trained across ALL ranks;
            'sampler_type': self.__class__.__name__,
            'length': len(self.dataset),
            'shuffle': self.shuffle,
            'batch_size': self.batch_size,
            'num_batch_per_bucket': self.num_batch_per_bucket,
            'num_replicas': self.num_replicas
        }
        return states

    def load_state_dict(self, states: Dict):
        # if self.during_iter were True, data_idx would necessarily be 0;
        assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                          "during an unfinished iteration."
        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."
        length = states['length']
        assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
                                            "and current dataset."
        self.seed = states['seed']
        self.epoch = states['epoch']
        self.num_consumed_samples = states['num_consumed_samples']
        if self.num_consumed_samples>=length:  # checkpoint was saved at the last sample — reset to 0
            self.num_consumed_samples = 0
        if self.shuffle != states['shuffle']:
            logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
                        f"we use shuffle={states['shuffle']}")
        self.shuffle = states["shuffle"]
        self.old_batch_size = states['batch_size']
        self.old_num_batch_per_bucket = states['num_batch_per_bucket']
        self.old_num_replicas = states['num_replicas']

    def set_epoch(self, epoch):
        self.epoch = epoch
| @@ -1,14 +1,12 @@ | |||
| from typing import Dict, List | |||
| import math | |||
| import numpy as np | |||
| from array import array | |||
| from copy import deepcopy | |||
| from fastNLP.core.log import logger | |||
| __all__ = [ | |||
| 'ReproducibleIterator', | |||
| 'RandomSampler', | |||
| 'ReproducibleBatchSampler', | |||
| 're_instantiate_sampler' | |||
| ] | |||
| @@ -22,7 +20,8 @@ def re_instantiate_sampler(sampler): | |||
| class ReproducibleIterator: | |||
| """ | |||
| 注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||
| 或者 batch_sampler; | |||
| 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | |||
| """ | |||
| def set_distributed(self, num_replicas, rank, pad=True): | |||
| @@ -72,7 +71,7 @@ class RandomSampler(ReproducibleIterator): | |||
| self.pad = kwargs.get("pad", False) # 该参数在单卡上不具有任何意义; | |||
| # 是否处于iteration之间,为True不允许调用 set_distributed()和load_state_dict() | |||
| self._during_iter = kwargs.get("_during_iter", False) | |||
| self.during_iter = kwargs.get("during_iter", False) | |||
| def __len__(self): | |||
| """ | |||
| @@ -92,9 +91,9 @@ class RandomSampler(ReproducibleIterator): | |||
| >>> next(iter2) # 当前num_consumed_samples的数量会发生变化 | |||
| """ | |||
| if self._during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | |||
| if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | |||
| self.num_consumed_samples = 0 | |||
| self._during_iter = True | |||
| self.during_iter = True | |||
| indices = self.generate_indices() | |||
| if self.pad: | |||
| @@ -118,7 +117,7 @@ class RandomSampler(ReproducibleIterator): | |||
| for index in indices: | |||
| self.num_consumed_samples += self.num_replicas | |||
| yield index | |||
| self._during_iter = False | |||
| self.during_iter = False | |||
| self.num_consumed_samples = 0 | |||
| def generate_indices(self) -> List[int]: | |||
| @@ -150,8 +149,8 @@ class RandomSampler(ReproducibleIterator): | |||
| return states | |||
| def load_state_dict(self, states: Dict): | |||
| # 如果 self._during_iter 是 True,那么 data_idx 一定是 0; | |||
| assert self._during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
| # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||
| assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
| "during an unfinished iteration." | |||
| assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | |||
| @@ -165,6 +164,9 @@ class RandomSampler(ReproducibleIterator): | |||
| self.num_consumed_samples = states['num_consumed_samples'] | |||
| if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||
| self.num_consumed_samples = 0 | |||
| if self.shuffle != states['shuffle']: | |||
| logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | |||
| f"we use shuffle={states['shuffle']}") | |||
| self.shuffle = states["shuffle"] | |||
| def set_epoch(self, epoch: int) -> None: | |||
| @@ -181,7 +183,7 @@ class RandomSampler(ReproducibleIterator): | |||
| :return: | |||
| """ | |||
| assert self._during_iter is False, "Cannot set the sampler to be distributed when it is " \ | |||
| assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \ | |||
| "during an unfinished iteration." | |||
| assert num_replicas>0 and isinstance(num_replicas, int) | |||
| assert isinstance(rank, int) and 0<=rank<num_replicas | |||
| @@ -204,7 +206,7 @@ class RandomSampler(ReproducibleIterator): | |||
| @property | |||
| def num_left_samples(self): | |||
| """ | |||
| 返回当前 iteration 还有多少个 sample 结束 | |||
| 返回当前 iteration 还有多少个 sample 结束。表示的是当前 rank 的还剩多少 | |||
| :return: | |||
| """ | |||
| @@ -213,119 +215,8 @@ class RandomSampler(ReproducibleIterator): | |||
| self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||
class ReproducibleBatchSampler:
    # Legacy (pre-refactor) version using the private `_index_list` attribute.
    # batch_sampler / batch_size / drop_last are expected to be obtained via the
    # driver's get_dataloader_args function;
    def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
        """
        Wrapper that makes the iteration state of ``batch_sampler`` recoverable.

        :param batch_sampler: an iterable yielding ints or lists of ints. It is fully
            iterated once up front and the flattened indices are cached; batches of
            ``batch_size`` indices are then served from that cache.
        :param batch_size: number of samples per batch.
        :param drop_last: whether to drop the final batch when it holds fewer than
            ``batch_size`` samples.
        :param kwargs: reserved for fastNLP internal use (state restoration).
        """
        self.batch_sampler = batch_sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        # position (within _index_list) of the next index to serve
        self.data_idx = kwargs.get("data_idx", 0)
        self._index_list = kwargs.get("_index_list", self._iterate_sampler())
        # False only for the first epoch; afterwards __iter__ re-walks batch_sampler
        self.need_reinitialize = kwargs.get("need_reinitialize", False)

    def _iterate_sampler(self):
        """Flatten ``batch_sampler`` into a compact ``array`` of sample indices."""
        _index_lst = []
        for idx in self.batch_sampler:
            if isinstance(idx, list):
                _index_lst.extend(idx)
            # a bare int means a plain sampler was passed in, i.e. the dataloader was
            # created without batch_size and without batch_sampler;
            else:
                _index_lst.append(idx)
        # unsigned int is 4 bytes on 64-bit machines, max value 4294967295;
        # NOTE(review): 'L' is also 4 bytes on LLP64 platforms (Windows), so the
        # fallback branch can overflow there — confirm.
        if len(_index_lst) > 4294967295:
            # self._index_list holds the indices of ALL data;
            # unsigned long
            _index_lst = array("L", _index_lst)
        else:
            # unsigned int
            _index_lst = array("I", _index_lst)
        return _index_lst

    def __iter__(self):
        if self.need_reinitialize:
            # not the first epoch: re-walk the wrapped sampler for a fresh ordering
            self._index_list = self._iterate_sampler()
            self.data_idx = 0
        else:
            self.need_reinitialize = True
        batch = []
        if self.data_idx:
            # resuming mid-epoch: skip indices that were already consumed
            index_list = self._index_list[self.data_idx:]
        else:
            index_list = self._index_list
        for idx in index_list:
            batch.append(idx)
            self.data_idx += 1
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        if self.drop_last:
            return len(self._index_list) // self.batch_size
        else:
            return (len(self._index_list) + self.batch_size - 1) // self.batch_size

    def state_dict(self) -> Dict:
        return {"index_list": deepcopy(self._index_list), "data_idx": self.data_idx, 'sampler_type': self.__class__.__name__}

    def load_state_dict(self, states: Dict):
        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."
        _index_list = states["index_list"]
        assert len(_index_list) == len(self._index_list), "The number of samples is different between the checkpoint " \
                                                          "record and current dataset."
        self._index_list = _index_list
        self.data_idx = states["data_idx"]
        # the restored index list must be served as-is, not regenerated
        self.need_reinitialize = False

    def set_distributed(self):
        raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")

    def set_epoch(self, epoch):
        # forward the epoch to the wrapped sampler when it supports reshuffling
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
            self.batch_sampler.sampler.set_epoch(epoch)

    @property
    def batch_idx_in_epoch(self):
        # number of batches already yielded in the current epoch
        if self.drop_last:
            return len(self._index_list) // self.batch_size - (len(self._index_list) - self.data_idx) // self.batch_size
        else:
            return (len(self._index_list) + self.batch_size - 1) // self.batch_size - \
                   (len(self._index_list) - self.data_idx + self.batch_size - 1) // self.batch_size
| # todo | |||
| # class SortedSampler(ReproducibleIterator): | |||
| # def __init__(self, dataset, key): | |||
| # pass | |||
| # | |||
| # | |||
| # class BucketedSampler(ReproducibleIterator): | |||
| # def __init__(self, dataset, key): | |||
| # pass | |||
if __name__ == "__main__":
    # Ad-hoc smoke check: print the attribute dicts of freshly built samplers.
    # NOTE(review): RandomSampler(1) passes an int where a dataset is expected —
    # works only because vars() is called before any iteration; confirm intent.
    sampler = RandomSampler(1)
    print(vars(sampler))
    batch_sampler = ReproducibleBatchSampler(list(range(3)), 1, True)
    print(vars(batch_sampler))
| @@ -1,6 +1,6 @@ | |||
| import unittest | |||
| from fastNLP.core.dataloaders.torch_dataloader import FDataLoader, prepare_dataloader | |||
| from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.io.data_bundle import DataBundle | |||
| @@ -9,17 +9,17 @@ class TestFdl(unittest.TestCase): | |||
| def test_init_v1(self): | |||
| ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
| fdl = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | |||
| fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | |||
| # for batch in fdl: | |||
| # print(batch) | |||
| fdl1 = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True) | |||
| fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True) | |||
| # for batch in fdl1: | |||
| # print(batch) | |||
| def test_set_padding(self): | |||
| ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
| ds.set_pad_val("x", val=-1) | |||
| fdl = FDataLoader(ds, batch_size=3) | |||
| fdl = TorchDataLoader(ds, batch_size=3) | |||
| fdl.set_input("x", "y") | |||
| for batch in fdl: | |||
| print(batch) | |||
| @@ -36,7 +36,7 @@ class TestFdl(unittest.TestCase): | |||
| _dict["Y"].append(ins['y']) | |||
| return _dict | |||
| fdl = FDataLoader(ds, batch_size=3, as_numpy=True) | |||
| fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True) | |||
| fdl.set_input("x", "y") | |||
| fdl.add_collator(collate_fn) | |||
| for batch in fdl: | |||
| @@ -44,7 +44,7 @@ class TestFdl(unittest.TestCase): | |||
| def test_get_batch_indices(self): | |||
| ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
| fdl = FDataLoader(ds, batch_size=3, shuffle=True) | |||
| fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) | |||
| fdl.set_input("y", "x") | |||
| for batch in fdl: | |||
| print(fdl.get_batch_indices()) | |||
| @@ -67,30 +67,30 @@ class TestFdl(unittest.TestCase): | |||
| return object.__getattribute__(self, item) | |||
| dataset = _DataSet() | |||
| dl = FDataLoader(dataset, batch_size=2, shuffle=True) | |||
| dl = TorchDataLoader(dataset, batch_size=2, shuffle=True) | |||
| # dl.set_inputs('data', 'labels') | |||
| # dl.set_pad_val('labels', val=None) | |||
| for batch in dl: | |||
| print(batch) | |||
| print(dl.get_batch_indices()) | |||
| def test_prepare_dataloader(self): | |||
| def test_prepare_torch_dataloader(self): | |||
| ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
| dl = prepare_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | |||
| assert isinstance(dl, FDataLoader) | |||
| dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | |||
| assert isinstance(dl, TorchDataLoader) | |||
| ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
| dbl = DataBundle(datasets={'train': ds, 'val': ds1}) | |||
| dl_bundle = prepare_dataloader(dbl) | |||
| assert isinstance(dl_bundle['train'], FDataLoader) | |||
| assert isinstance(dl_bundle['val'], FDataLoader) | |||
| dl_bundle = prepare_torch_dataloader(dbl) | |||
| assert isinstance(dl_bundle['train'], TorchDataLoader) | |||
| assert isinstance(dl_bundle['val'], TorchDataLoader) | |||
| ds_dict = {'train_1': ds, 'val': ds1} | |||
| dl_dict = prepare_dataloader(ds_dict) | |||
| assert isinstance(dl_dict['train_1'], FDataLoader) | |||
| assert isinstance(dl_dict['val'], FDataLoader) | |||
| dl_dict = prepare_torch_dataloader(ds_dict) | |||
| assert isinstance(dl_dict['train_1'], TorchDataLoader) | |||
| assert isinstance(dl_dict['val'], TorchDataLoader) | |||
| sequence = [ds, ds1] | |||
| seq_ds = prepare_dataloader(sequence) | |||
| assert isinstance(seq_ds[0], FDataLoader) | |||
| assert isinstance(seq_ds[1], FDataLoader) | |||
| seq_ds = prepare_torch_dataloader(sequence) | |||
| assert isinstance(seq_ds[0], TorchDataLoader) | |||
| assert isinstance(seq_ds[1], TorchDataLoader) | |||
| @@ -9,7 +9,8 @@ import paddle | |||
| from paddle.io import DataLoader, BatchSampler | |||
| from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||
| from fastNLP.core.samplers.reproducible_sampler import ReproducibleBatchSampler, RandomSampler | |||
| from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler | |||
| from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||
| from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | |||
| from fastNLP.core import synchronize_safe_rm | |||
| @@ -118,7 +118,6 @@ class TestAccuracy: | |||
| def test_v1(self, is_ddp: bool, dataset: DataSet, metric_class: Type['Metric'], | |||
| metric_kwargs: Dict[str, Any]) -> None: | |||
| global pool | |||
| print(pool) | |||
| if is_ddp: | |||
| if sys.platform == "win32": | |||
| pytest.skip("DDP not supported on windows") | |||
| @@ -14,6 +14,7 @@ from torch.multiprocessing import Pool, set_start_method | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.core.metrics import SpanFPreRecMetric | |||
| from fastNLP.core.dataset import DataSet | |||
| set_start_method("spawn", force=True) | |||
| @@ -45,7 +46,6 @@ def setup_ddp(rank: int, world_size: int, master_port: int) -> None: | |||
| os.environ["MASTER_ADDR"] = "localhost" | |||
| os.environ["MASTER_PORT"] = str(master_port) | |||
| print(torch.cuda.device_count()) | |||
| if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): | |||
| torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) | |||
| @@ -64,15 +64,15 @@ def find_free_network_port() -> int: | |||
| return port | |||
| @pytest.fixture(scope='class', autouse=True) | |||
| def pre_process(): | |||
| global pool | |||
| pool = Pool(processes=NUM_PROCESSES) | |||
| master_port = find_free_network_port() | |||
| pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) | |||
| yield | |||
| pool.close() | |||
| pool.join() | |||
| # @pytest.fixture(scope='class', autouse=True) | |||
| # def pre_process(): | |||
| # global pool | |||
| # pool = Pool(processes=NUM_PROCESSES) | |||
| # master_port = find_free_network_port() | |||
| # pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) | |||
| # yield | |||
| # pool.close() | |||
| # pool.join() | |||
| def _test(local_rank: int, | |||
| @@ -87,18 +87,19 @@ def _test(local_rank: int, | |||
| # dataset 也类似(每个进程有自己的一个) | |||
| dataset = copy.deepcopy(dataset) | |||
| metric.to(device) | |||
| print(os.environ.get("MASTER_PORT", "xx")) | |||
| # 把数据拆到每个 GPU 上,有点模仿 DistributedSampler 的感觉,但这里数据单位是一个 batch(即每个 i 取了一个 batch 到自己的 GPU 上) | |||
| for i in range(local_rank, len(dataset), world_size): | |||
| pred, tg, seq_len = dataset[i]['pred'].to(device), dataset[i]['tg'].to(device), dataset[i]['seq_len'] | |||
| print(tg, seq_len) | |||
| metric.update(pred, tg, seq_len) | |||
| my_result = metric.get_metric() | |||
| print(my_result) | |||
| print(sklearn_metric) | |||
| assert my_result == sklearn_metric | |||
| class SpanFPreRecMetricTest(unittest.TestCase): | |||
| global pool | |||
| def test_case1(self): | |||
| from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans | |||
| @@ -147,26 +148,26 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
| -1.3508, -0.9513], | |||
| [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||
| -0.0842, -0.4294]], | |||
| [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||
| -1.4138, -0.8853], | |||
| [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||
| -1.0726, 0.0364], | |||
| [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||
| -0.8836, -0.9320], | |||
| [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||
| -1.6857, 1.1571], | |||
| [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||
| -0.5837, 1.0184], | |||
| [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||
| -0.9025, 0.0864]]]) | |||
| bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], | |||
| [4, 1, 7, 0, 4, 7]]) | |||
| [ | |||
| [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||
| -1.4138, -0.8853], | |||
| [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||
| -1.0726, 0.0364], | |||
| [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||
| -0.8836, -0.9320], | |||
| [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||
| -1.6857, 1.1571], | |||
| [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||
| -0.5837, 1.0184], | |||
| [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||
| -0.9025, 0.0864]] | |||
| ] | |||
| ]) | |||
| bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]]) | |||
| fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6]) | |||
| expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, | |||
| 'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, | |||
| 'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} | |||
| assert expect_bio_res == fastnlp_bio_metric.get_metric() | |||
| # print(fastnlp_bio_metric.get_metric()) | |||
| @@ -325,44 +326,52 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
| metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | |||
| def test_case5(self): | |||
| global pool | |||
| # pool = Pool(NUM_PROCESSES) | |||
| # master_port = find_free_network_port() | |||
| # pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) | |||
| # global pool | |||
| pool = Pool(NUM_PROCESSES) | |||
| master_port = find_free_network_port() | |||
| pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) | |||
| number_labels = 4 | |||
| # bio tag | |||
| fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||
| fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | |||
| # fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||
| dataset = DataSet({'pred': [torch.FloatTensor( | |||
| [[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||
| -0.3782, 0.8240], | |||
| [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||
| -0.3562, -1.4116], | |||
| [1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||
| 2.0023, 0.7075], | |||
| [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||
| 0.3832, -0.1540], | |||
| [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||
| -1.3508, -0.9513], | |||
| [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||
| -0.0842, -0.4294]], | |||
| [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||
| -1.4138, -0.8853], | |||
| [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||
| -1.0726, 0.0364], | |||
| [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||
| -0.8836, -0.9320], | |||
| [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||
| -1.6857, 1.1571], | |||
| [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||
| -0.5837, 1.0184], | |||
| [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||
| -0.9025, 0.0864]]])] * 100, | |||
| 'tg': [torch.LongTensor([[3, 6, 0, 8, 2, 4], | |||
| [4, 1, 7, 0, 4, 7]])] * 100, | |||
| 'seq_len': [[6, 6]] * 100}) | |||
| dataset = DataSet({'pred': [ | |||
| torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||
| -0.3782, 0.8240], | |||
| [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||
| -0.3562, -1.4116], | |||
| [1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||
| 2.0023, 0.7075], | |||
| [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||
| 0.3832, -0.1540], | |||
| [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||
| -1.3508, -0.9513], | |||
| [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||
| -0.0842, -0.4294]] | |||
| ]), | |||
| torch.FloatTensor([ | |||
| [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||
| -1.4138, -0.8853], | |||
| [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||
| -1.0726, 0.0364], | |||
| [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||
| -0.8836, -0.9320], | |||
| [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||
| -1.6857, 1.1571], | |||
| [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||
| -0.5837, 1.0184], | |||
| [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||
| -0.9025, 0.0864]] | |||
| ]) | |||
| ], | |||
| 'tg': [ | |||
| torch.LongTensor([[3, 6, 0, 8, 2, 4]]), | |||
| torch.LongTensor([[4, 1, 7, 0, 4, 7]]) | |||
| ], | |||
| 'seq_len': [ | |||
| [6], [6] | |||
| ]}) | |||
| metric_kwargs = { | |||
| 'tag_vocab': fastnlp_bio_vocab, | |||
| 'only_gross': False, | |||
| @@ -372,7 +381,6 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
| 'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, | |||
| 'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} | |||
| processes = NUM_PROCESSES | |||
| print(torch.cuda.device_count()) | |||
| pool.starmap( | |||
| partial( | |||
| @@ -384,3 +392,5 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||
| ), | |||
| [(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)] | |||
| ) | |||
| pool.close() | |||
| pool.join() | |||
| @@ -0,0 +1,438 @@ | |||
| from array import array | |||
| import numpy as np | |||
| import pytest | |||
| from itertools import chain | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler | |||
| from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
class TestReproducibleBatchSampler:
    """Tests for ReproducibleBatchSampler wrapped around a torch DataLoader.

    Covers checkpoint/resume semantics: saving the sampler state mid-epoch,
    restoring it into a freshly built dataloader (with the same or a different
    batch size), and verifying that the epoch after a resumed one regenerates
    its index list.
    """
    # TODO: split this up; each test should check exactly one thing.
    def test_torch_dataloader_1(self):
        """Save sampler state mid-epoch and resume, with and without a batch-size change."""
        import torch
        from torch.utils.data import DataLoader
        # no shuffle
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            next(iter_dataloader)
        # 1. Save state after consuming `forward_steps` batches.
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
        state = _get_re_batchsampler.state_dict()
        assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size,
                         "sampler_type": "ReproducibleBatchSampler"}
        # 2. Resume from checkpoint: rebuild a dataloader and restore the state;
        # first without changing batch_size;
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        real_res = []
        # Resumption should continue from sample index 21 (= 3 steps * 7).
        supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))
        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])
        # Now change batch_size; the resumed stream must still start at index 21.
        after_batch_size = 3
        dataloader = DataLoader(dataset, batch_size=after_batch_size)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        real_res = []
        supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))
        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])
        # Verify the epoch after the resumed one is a full dataloader pass;
        # first run the resumed epoch to completion;
        begin_idx = 27
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break
        # Start a new epoch; it must cover the whole dataset from index 0.
        begin_idx = 0
        iter_dataloader = iter(dataloader)
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break
    def test_torch_dataloader_2(self):
        """A new epoch's index list must be regenerated, not carried over from the last epoch."""
        from torch.utils.data import DataLoader
        # no shuffle
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        # shuffle=True so we can tell whether the post-resume second epoch
        # regenerated its index list.
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        # Record every sample of the epoch so the resume can be checked exactly.
        all_supposed_data = []
        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            all_supposed_data.extend(next(iter_dataloader).tolist())
        # 1. Save state
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
        state = _get_re_batchsampler.state_dict()
        # 2. Resume from checkpoint: rebuild the dataloader;
        # batch_size unchanged;
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        # Run the current epoch to completion first.
        # NOTE(review): this keeps draining the *old* iterator — `iter_dataloader`
        # was created before the dataloader was rebuilt. Presumably intentional,
        # since both share the saved index list — confirm.
        pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
        while True:
            try:
                all_supposed_data.extend(next(iter_dataloader).tolist())
            except StopIteration:
                break
        assert all_supposed_data == list(pre_index_list)
        # Start fresh epochs; each full pass must iterate to exhaustion.
        for _ in range(3):
            iter_dataloader = iter(dataloader)
            res = []
            while True:
                try:
                    res.append(next(iter_dataloader))
                except StopIteration:
                    break
    def test_3(self):
        """Smoke test: break out of a plain DataLoader loop, then iterate it fully again."""
        import torch
        from torch.utils.data import DataLoader
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        # Plain (non-reproducible) dataloader; no shuffle here.
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        for idx, data in enumerate(dataloader):
            if idx > 3:
                break
        iterator = iter(dataloader)
        for each in iterator:
            pass
class DatasetWithVaryLength:
    """Toy map-style dataset where item ``i`` is simply ``i``.

    Because each sample's value equals its own index, the values double as
    per-sample "lengths" for the bucketed-sampler tests.
    """
    def __init__(self, num_of_data=100):
        # data[i] == i for every index in [0, num_of_data).
        self.data = np.arange(num_of_data)
    def __getitem__(self, item):
        return self.data[item]
    def __len__(self):
        return self.data.shape[0]
class TestBucketedBatchSampler:
    """Tests for BucketedBatchSampler: length bucketing, checkpoint/resume,
    and distributed (multi-replica) splitting with optional padding.
    """
    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71])
    def test_single_num_batch(self, shuffle, drop_last, num):
        """Small datasets must not raise, and the batch count must match drop_last semantics."""
        # Too little data should not raise an error.
        # NOTE(review): this loop shadows the parametrized `num`, so the
        # `num` parameter is effectively unused — confirm whether the
        # parametrize or the loop should be dropped.
        for num in [2, 7, 14, 15, 70, 71]:
            dataset = DatasetWithVaryLength(num_of_data=num)
            before_batch_size = 7
            re_batchsampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
                                                   num_batch_per_bucket=10, drop_last=drop_last,
                                                   shuffle=shuffle)
            count = len(list(iter(re_batchsampler)))
            if drop_last:
                assert count==num//before_batch_size, num
            else:
                assert count==(num+before_batch_size-1)//before_batch_size, num
    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    def test_single(self, shuffle, drop_last):
        """Single-process save/resume, including batch-size changes after resuming."""
        before_batch_size = 7
        num_batch_per_bucket = 4  # so the length spread inside any batch should stay within 4 buckets' worth
        dataset = DatasetWithVaryLength(num_of_data=1000)
        re_batchsampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
                                               num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last,
                                               shuffle=shuffle)
        re_batchsampler.set_epoch(0)
        forward_steps = 10
        iterator = iter(re_batchsampler)
        already_generate_indices = set()
        for _ in range(forward_steps):
            batch = next(iterator)
            assert max(batch) - min(batch) <= before_batch_size * num_batch_per_bucket
            already_generate_indices.update(batch)
        # 1. Save state
        state = re_batchsampler.state_dict()
        # 2. Resume from checkpoint and keep iterating
        re_batchsampler2 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
                                                num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last,
                                                shuffle=shuffle)
        re_batchsampler2.load_state_dict(state)
        re_batchsampler2.set_epoch(0)
        new_already_generate_indices = set()
        # Compute the tightest spread bound achievable over the indices that
        # have not been emitted yet.
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_generate_indices)] = 0
        indices = np.arange(len(dataset))[mask]
        max_diff = -1
        for i in range(len(indices)-before_batch_size * num_batch_per_bucket):
            max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i])
        for batch in re_batchsampler2:
            assert max(batch) - min(batch) <= max_diff
            for b in batch:
                assert b not in already_generate_indices
            new_already_generate_indices.update(batch)
        if drop_last is False:
            # Without drop_last the resumed run must cover all remaining samples.
            assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset)
        # Change batch_size;
        after_batch_size = 3
        re_batchsampler3 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
                                                num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last,
                                                shuffle=shuffle)
        re_batchsampler3.load_state_dict(state)
        re_batchsampler3.set_epoch(0)
        count = 0
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_generate_indices)] = 0
        indices = np.arange(len(dataset))[mask]
        max_diff = -1
        for i in range(len(indices)-after_batch_size * num_batch_per_bucket):
            max_diff = max(max_diff, indices[i+after_batch_size * num_batch_per_bucket]-indices[i])
        for batch in re_batchsampler3:
            assert max(batch) - min(batch) <= max_diff
            for b in batch:
                assert b not in already_generate_indices
            already_generate_indices.update(batch)
            count += 1
            if count > 5:
                break
        # Saving again mid-epoch is not allowed while the previous epoch is unfinished.
        after_batch_size = 5
        with pytest.raises(RuntimeError):
            state = re_batchsampler3.state_dict()
        for batch in re_batchsampler3:  # consume everything so that saving becomes legal
            pass
        already_generate_indices = set()
        count = 0
        for batch in re_batchsampler3:  # start over
            assert max(batch) - min(batch) <= max_diff
            for b in batch:
                assert b not in already_generate_indices
            already_generate_indices.update(batch)
            count += 1
            if count > 5:
                break
        state = re_batchsampler3.state_dict()
        # drop_last is False here, so eventually every sample must be emitted.
        re_batchsampler4 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
                                                num_batch_per_bucket=num_batch_per_bucket, drop_last=False,
                                                shuffle=shuffle)
        re_batchsampler4.load_state_dict(state)
        re_batchsampler4.set_epoch(0)
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_generate_indices)] = 0
        indices = np.arange(len(dataset))[mask]
        max_diff = -1
        for i in range(len(indices) - after_batch_size * num_batch_per_bucket):
            max_diff = max(max_diff, indices[i + after_batch_size * num_batch_per_bucket] - indices[i])
        for batch in re_batchsampler4:
            assert max(batch) - min(batch) <= max_diff
            for b in batch:
                assert b not in already_generate_indices
            already_generate_indices.update(batch)
        assert len(already_generate_indices) == len(dataset)
    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('pad', [True, False])
    def test_multi(self, shuffle, drop_last, pad):
        """Multi-replica splitting, saving from one replica, and resuming with
        a different replica count / on a single process."""
        # def test_multi(self, shuffle=True, drop_last=False, pad=False):
        # no shuffle
        num_replica = 2
        dataset = DatasetWithVaryLength(num_of_data=1000)
        batch_size = 5
        num_batch_per_bucket = 10
        lengths = []
        rank0_already_seen_indexes = None
        max_diff = num_batch_per_bucket * batch_size * num_replica
        for rank in range(num_replica):
            sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
                                           num_batch_per_bucket = num_batch_per_bucket,
                                           shuffle = shuffle, drop_last=drop_last)
            sampler.set_epoch(0)
            sampler.set_distributed(num_replica, rank=rank, pad=pad)
            lengths.append(len(sampler))
            already_seen_indexes = set()
            repeat_count = 0
            for batch in sampler:
                assert max_diff>=max(batch)-min(batch)
                for b in batch:
                    repeat_count += int(b in already_seen_indexes)
                    if rank0_already_seen_indexes:  # replicas must not overlap
                        assert b not in rank0_already_seen_indexes
                already_seen_indexes.update(batch)
            if rank0_already_seen_indexes is None:
                rank0_already_seen_indexes = already_seen_indexes
            if pad:  # padding may duplicate at most one sample
                assert repeat_count<=1
            else:
                assert repeat_count==0
        assert len(set(lengths))==1, lengths  # every replica must see the same number of batches
        # Multi-process saving
        already_seen_indexes = set()
        for rank in range(num_replica):
            sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
                                           num_batch_per_bucket = num_batch_per_bucket,
                                           shuffle = shuffle, drop_last=drop_last)
            sampler.set_epoch(0)
            sampler.set_distributed(num_replica, rank=rank, pad=pad)
            lengths.append(len(sampler))
            count = 0
            for batch in sampler:
                assert max_diff>=max(batch)-min(batch)
                already_seen_indexes.update(batch)
                if count>5:
                    break
                count += 1
            state = sampler.state_dict()
        # Switch to a single process
        new_batch_size = 6
        num_batch_per_bucket = 3
        new_sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
                                           num_batch_per_bucket=num_batch_per_bucket,
                                           shuffle=shuffle, drop_last=drop_last)
        new_sampler.load_state_dict(state)
        repeat_count = 0
        new_already_seen_indexes = set(list(already_seen_indexes))
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_seen_indexes)] = 0
        indices = np.arange(len(dataset))[mask]
        max_diff = -1
        for i in range(len(indices)-new_batch_size * num_batch_per_bucket):
            max_diff = max(max_diff, indices[i+new_batch_size * num_batch_per_bucket]-indices[i])
        for batch in new_sampler:
            assert max_diff>=max(batch)-min(batch)
            for b in batch:
                repeat_count += int(b in new_already_seen_indexes)
            new_already_seen_indexes.update(batch)
        if pad:  # padding may duplicate at most one sample
            assert repeat_count <= 1
        else:
            assert repeat_count == 0
        if drop_last is False:  # without drop the counts must be equal
            assert len(new_already_seen_indexes)==len(dataset)
        # Test resuming with a different number of devices.
        num_replica = 3
        new_sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
                                           num_batch_per_bucket=num_batch_per_bucket,
                                           shuffle=shuffle, drop_last=drop_last)
        new_sampler.set_epoch(0)
        new_sampler.load_state_dict(state)
        new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad)
        repeat_count = 0
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_seen_indexes)] = 0
        indices = np.arange(len(dataset))[mask]
        max_diff = -1
        for i in range(len(indices) - new_batch_size * num_batch_per_bucket*num_replica):
            max_diff = max(max_diff, indices[i + new_batch_size * num_batch_per_bucket*num_replica] - indices[i])
        for batch in new_sampler:
            assert max_diff>=max(batch)-min(batch)
            for b in batch:
                repeat_count += int(b in already_seen_indexes)
        if pad:  # padding may duplicate at most one sample
            assert repeat_count <= 1
        else:
            assert repeat_count == 0
    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
    @pytest.mark.parametrize('num_replica', [2, 3])
    def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica):
        """Parallel batches from all replicas must come from the same bucket window."""
        # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2):
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        batch_size = 6
        if num_replica*batch_size > num_samples:
            return
        num_batch_per_bucket = 10
        samplers = []
        lengths = []
        for i in range(num_replica):
            sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                           num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
            sampler.set_distributed(num_replica, rank=i, pad=pad)
            sampler.set_epoch(0)
            samplers.append(sampler)
            lengths.append(len(list(iter(sampler))))
        assert len(set(lengths))==1
        bucket_diff = batch_size * num_batch_per_bucket * num_replica
        for bs in zip(*samplers):
            # The k-th batch of every replica should draw from one shared bucket.
            diff = max(chain(*bs)) - min(chain(*bs))
            assert diff <= bucket_diff
| @@ -6,7 +6,7 @@ import numpy as np | |||
| from functools import partial | |||
| from array import array | |||
| from fastNLP.core.samplers.reproducible_sampler import RandomSampler, ReproducibleBatchSampler | |||
| from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||
| from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||
| from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
| @@ -361,148 +361,3 @@ class TestRandomSampler(unittest.TestCase): | |||
class TestReproducibleBatchSampler:
    """Checkpoint-resume ("断点重训") tests for ``ReproducibleBatchSampler``.

    The sampler must be able to dump its mid-epoch iteration state via
    ``state_dict`` and, after being re-created around a fresh ``DataLoader``,
    resume from exactly the same sample via ``load_state_dict``.
    """

    def test_torch_dataloader_1(self):
        import torch
        from torch.utils.data import DataLoader

        # No shuffle: every batch is a contiguous index range, so the expected
        # values can be written down exactly.
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            next(iter_dataloader)

        # 1. Save the sampler state mid-epoch (3 batches * 7 samples consumed).
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
        state = _get_re_batchsampler.state_dict()
        assert state == {"index_list": array("I", list(range(100))),
                         "data_idx": forward_steps * before_batch_size,
                         "sampler_type": "ReproducibleBatchSampler"}

        # 2. Simulate a restart: build a fresh dataloader and restore the state.
        # 2a. Same batch_size: iteration must resume exactly at sample 21.
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        real_res = []
        supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))
        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])

        # 2b. Changed batch_size (7 -> 3): the resume point (sample 21) must be
        # preserved even though batches are now cut differently.
        after_batch_size = 3
        dataloader = DataLoader(dataset, batch_size=after_batch_size)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        real_res = []
        supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))
        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])

        # 3. Finish the resumed epoch, then check that the epoch after a resume
        # is a complete pass starting again at index 0.
        begin_idx = 27
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break

        begin_idx = 0
        iter_dataloader = iter(dataloader)
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break

    def test_torch_dataloader_2(self):
        # Verify that after a checkpoint-resume, each NEW epoch regenerates its
        # index list instead of replaying the previous one.
        from torch.utils.data import DataLoader

        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        # shuffle=True so that index lists of different epochs can actually differ.
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        # Record everything produced before the simulated checkpoint.
        all_supposed_data = []
        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            all_supposed_data.extend(next(iter_dataloader).tolist())

        # 1. Save state.
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
        state = _get_re_batchsampler.state_dict()

        # 2. Simulate a restart: fresh dataloader, state restored.
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        # Drain the rest of the resumed epoch from the RESTORED dataloader.
        # (Bug fix: the original drained the stale pre-checkpoint iterator,
        # which only passed because both samplers share the same index list.)
        pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
        iter_dataloader = iter(dataloader)
        while True:
            try:
                all_supposed_data.extend(next(iter_dataloader).tolist())
            except StopIteration:
                break
        # Pre-checkpoint batches + resumed batches reproduce the saved index list.
        assert all_supposed_data == list(pre_index_list)

        # 3. Every subsequent epoch must be a complete pass over the dataset.
        for _ in range(3):
            iter_dataloader = iter(dataloader)
            res = []
            while True:
                try:
                    res.extend(next(iter_dataloader).tolist())
                except StopIteration:
                    break
            # (Added: the original collected the batches but asserted nothing.)
            assert sorted(res) == list(range(100))

    def test_3(self):
        from torch.utils.data import DataLoader

        # A plain DataLoader (no ReproducibleBatchSampler) must tolerate being
        # partially consumed and then re-iterated from scratch.
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        dataloader = DataLoader(dataset, batch_size=before_batch_size)

        for idx, data in enumerate(dataloader):
            if idx > 3:
                break
        # (Added: the original had no assertions at all.)
        assert idx == 4  # loop stopped right after the 5th batch

        # A fresh iterator sees the full dataset again.
        n_samples = sum(len(batch) for batch in iter(dataloader))
        assert n_samples == len(dataset)