| @@ -58,7 +58,7 @@ class CheckpointCallback(Callback): | |||||
| """ | """ | ||||
| def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None, | ||||
| every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0, | every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0, | ||||
| on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = [EarlyStopException], | |||||
| on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = (EarlyStopException), | |||||
| monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, | monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True, | ||||
| only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', | only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', | ||||
| save_evaluate_results=True, **kwargs): | save_evaluate_results=True, **kwargs): | ||||
| @@ -402,10 +402,10 @@ class DataSet: | |||||
| def __getattr__(self, item): | def __getattr__(self, item): | ||||
| # Not tested. Don't use !! | # Not tested. Don't use !! | ||||
| if item == "field_arrays": | |||||
| raise AttributeError | |||||
| if isinstance(item, str) and item in self.field_arrays: | if isinstance(item, str) and item in self.field_arrays: | ||||
| return self.field_arrays[item] | return self.field_arrays[item] | ||||
| else: | |||||
| raise AttributeError | |||||
| def __setstate__(self, state): | def __setstate__(self, state): | ||||
| self.__dict__ = state | self.__dict__ = state | ||||
| @@ -121,7 +121,9 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||||
| :param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
| """ | """ | ||||
| def __init__(self, dataset, length:Union[str, List], **kwargs): | def __init__(self, dataset, length:Union[str, List], **kwargs): | ||||
| super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||||
| kwargs['shuffle'] = False | |||||
| kwargs['seed'] = 0 | |||||
| super().__init__(dataset=dataset, **kwargs) | |||||
| if isinstance(dataset, DataSet) and isinstance(length, str): | if isinstance(dataset, DataSet) and isinstance(length, str): | ||||
| length = dataset.get_field(length).content | length = dataset.get_field(length).content | ||||
| if not isinstance(length[0], int): | if not isinstance(length[0], int): | ||||
| @@ -141,17 +143,32 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||||
| class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | ||||
| """ | """ | ||||
| 按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | |||||
| 按照顺序读取 dataset。 | |||||
| :param dataset: 实现了 __len__ 方法的数据容器。 | :param dataset: 实现了 __len__ 方法的数据容器。 | ||||
| :param chunk_dist: 如果为 True ,当多卡时,将不间隔索取数据;为 False ,间隔取数据。例如,假设 dataset 有 10 个 sample ,使用 | |||||
| 2 卡,如果为 True ,卡 0 拿 [0, 1, 2, 3, 4], 卡 1 拿 [5, 6, 7, 8, 9] ; 如果为 False ,则卡 0 拿 [0, 2, 4, 8, 8], 卡 | |||||
| 1 拿 [1, 3, 5, 7, 9] 。 | |||||
| :param kwargs: | :param kwargs: | ||||
| """ | """ | ||||
| def __init__(self, dataset, **kwargs): | |||||
| super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs) | |||||
| def __init__(self, dataset, chunk_dist=False, **kwargs): | |||||
| kwargs['shuffle'] = False | |||||
| kwargs['seed'] = 0 | |||||
| super(UnrepeatedSequentialSampler, self).__init__(dataset, **kwargs) | |||||
| self.chunk_dist = chunk_dist | |||||
| def __iter__(self): | def __iter__(self): | ||||
| indices = self.generate_indices() | indices = self.generate_indices() | ||||
| indices = indices[self.rank:len(indices):self.num_replicas] | |||||
| if self.num_replicas>1: | |||||
| if self.chunk_dist: | |||||
| chunk_size = len(indices)//self.num_replicas | |||||
| start = chunk_size * self.rank | |||||
| end = chunk_size * (self.rank + 1) | |||||
| if self.rank == self.num_replicas - 1: | |||||
| end = len(indices) | |||||
| indices = indices[start:end] | |||||
| else: | |||||
| indices = indices[self.rank:len(indices):self.num_replicas] | |||||
| for index in indices: | for index in indices: | ||||
| yield index | yield index | ||||
| @@ -87,13 +87,14 @@ class TestUnrepeatedSequentialSampler: | |||||
| @pytest.mark.parametrize('num_replicas', [2, 3]) | @pytest.mark.parametrize('num_replicas', [2, 3]) | ||||
| @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
| def test_multi(self, num_replicas, num_of_data): | |||||
| @pytest.mark.parametrize('chunk_dist', [True, False]) | |||||
| def test_multi(self, num_replicas, num_of_data, chunk_dist): | |||||
| if num_replicas > num_of_data: | if num_replicas > num_of_data: | ||||
| pytest.skip("num_replicas > num_of_data") | pytest.skip("num_replicas > num_of_data") | ||||
| data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
| samplers = [] | samplers = [] | ||||
| for i in range(num_replicas): | for i in range(num_replicas): | ||||
| sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data) | |||||
| sampler = UnrepeatedSequentialSampler(dataset=data, chunk_dist=chunk_dist) | |||||
| sampler.set_distributed(num_replicas, rank=i) | sampler.set_distributed(num_replicas, rank=i) | ||||
| samplers.append(sampler) | samplers.append(sampler) | ||||