| @@ -2310,6 +2310,7 @@ class ConcatDataset(DatasetOp): | |||
| Raises: | |||
| TypeError: If dataset is not an instance of Dataset. | |||
| ValueError: If any one of the datasets contains no samples. | |||
| """ | |||
| def __init__(self, datasets): | |||
| @@ -2324,15 +2325,19 @@ class ConcatDataset(DatasetOp): | |||
| data.parent.append(self) | |||
| self.children_sizes_ = [c.get_dataset_size() for c in self.children] | |||
| """ | |||
| _children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes | |||
| whether the data set is mappable. The second element of pair is length of the dataset | |||
| """ | |||
| child_index = 0 | |||
| for item in self.children_sizes_: | |||
| if item == 0: | |||
| raise ValueError("There is no samples in the %dth dataset. Please make sure there are " | |||
| "valid samples in the dataset" % child_index) | |||
| child_index += 1 | |||
| # _children_flag_and_nums: A list of pair<int, int>. The first element of each pair is a flag that indicates | |||
| # whether the dataset is mappable. The second element of each pair is the length of the dataset. | |||
| self._children_flag_and_nums = [] | |||
| """ | |||
| _children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize | |||
| the valid position of the dataset corresponding to the subscript when sampling | |||
| """ | |||
| # _children_start_end_index_: A list of pair<int, int>. The elements of each pair characterize the valid | |||
| # positions, when sampling, of the child dataset at the corresponding index. | |||
| self._children_start_end_index_ = [] | |||
| for index, child in enumerate(self.children): | |||
| tem_list = [-1, -1] | |||
| @@ -1,4 +1,5 @@ | |||
| from io import BytesIO | |||
| import copy | |||
| import os | |||
| import numpy as np | |||
| import pytest | |||
| @@ -412,6 +413,46 @@ def test_Mindrecord_Padded(remove_mindrecord_file): | |||
| result_list.append(tem_list) | |||
| assert result_list == verify_list | |||
def test_clue_padded_and_skip_with_0_samples():
    """
    Test that concatenating a dataset which has no samples raises ValueError.

    A CLUE dataset is concatenated with a one-sample PaddedDataset and sharded
    with a DistributedSampler. Skipping every row of the resulting shard leaves
    a dataset with no samples, and calling concat() on it is expected to raise
    ValueError.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
    count = 0
    for _ in data.create_dict_iterator():
        count += 1
    assert count == 3

    # Keep an untouched copy of the source dataset for the final concat() call.
    data_copy1 = copy.deepcopy(data)

    # np.bytes_ is the same type as the np.string_ alias, which NumPy 2.0 removed.
    sample = {"label": np.array(1, np.bytes_),
              "sentence1": np.array(1, np.bytes_),
              "sentence2": np.array(1, np.bytes_)}
    samples = [sample]
    padded_ds = ds.PaddedDataset(samples)
    dataset = data + padded_ds
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
    dataset.use_sampler(testsampler)
    assert dataset.get_dataset_size() == 2
    count = 0
    for _ in dataset.create_dict_iterator():
        count += 1
    assert count == 2

    # Skip as many rows as the shard holds, so the dataset yields nothing.
    dataset = dataset.skip(count=2)
    count = 0
    for _ in dataset.create_dict_iterator():
        count += 1
    assert count == 0

    # concat() must reject the empty dataset. Nothing else belongs in this
    # block: statements placed after the raising call would never execute.
    with pytest.raises(ValueError, match="There is no samples in the "):
        dataset.concat(data_copy1)
| if __name__ == '__main__': | |||
| test_TFRecord_Padded() | |||