From 3e31ac6d62a115d97eb9b4ab60d6212ed8e0e28c Mon Sep 17 00:00:00 2001
From: anzhengqi
Date: Wed, 12 Aug 2020 15:35:47 +0800
Subject: [PATCH] all Dataset support get_dataset_size

---
 mindspore/dataset/engine/datasets.py           | 339 ++++++++++--------
 mindspore/dataset/engine/iterators.py          |   3 +
 mindspore/dataset/transforms/c_transforms.py   |   3 +-
 .../dataset/test_bucket_batch_by_length.py     |  20 ++
 .../python/dataset/test_datasets_generator.py  | 113 ++++++
 tests/ut/python/dataset/test_filterop.py       |  11 +
 6 files changed, 336 insertions(+), 153 deletions(-)

diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 2a615e01f7..aa895e70b8 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -143,7 +143,7 @@ class Dataset:
         self._input_indexs = ()
         self._output_types = None
         self._output_shapes = None
-        self._dataset_size = None
+        self.dataset_size = None
         self._batch_size = None
         self._num_classes = None
         self._repeat_count = None
@@ -1189,8 +1189,6 @@ class Dataset:
         device_iter = TupleIterator(self)
         self._output_shapes = device_iter.get_output_shapes()
         self._output_types = device_iter.get_output_types()
-        if self._dataset_size is None:
-            self._dataset_size = device_iter.get_dataset_size()
         self._batch_size = device_iter.get_batch_size()
         self._num_classes = device_iter.num_classes()
         self._repeat_count = device_iter.get_repeat_count()
@@ -1225,9 +1223,10 @@ class Dataset:
         Return:
             Number, number of batches.
         """
-        if self.children:
-            return self.children[0].get_dataset_size()
-        return None
+        if self.dataset_size is None:
+            if self.children:
+                self.dataset_size = self.children[0].get_dataset_size()
+        return self.dataset_size
 
     def num_classes(self):
         """
@@ -1378,6 +1377,8 @@ class MappableDataset(SourceDataset):
     def add_sampler(self, new_sampler):
         # note: by adding a sampler, we mean that the sampled ids will flow to new_sampler
        # after first passing through the current samplers attached to this dataset.
+        if self.dataset_size is not None:
+            self.dataset_size = None
         new_sampler.add_child(self.sampler)
         self.sampler = new_sampler
 
@@ -1406,6 +1407,8 @@ class MappableDataset(SourceDataset):
             raise TypeError("Input sampler can not be None.")
         if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
             raise TypeError("Input sampler is not an instance of a sampler.")
+        if self.dataset_size is not None:
+            self.dataset_size = None
         self.sampler = self.sampler.child_sampler
         self.add_sampler(new_sampler)
 
@@ -1505,6 +1508,7 @@ class MappableDataset(SourceDataset):
         current_split_start_index = 0
         for size in absolute_sizes:
             ds = copy.deepcopy(self)
+            ds.dataset_size = None
             if randomize:
                 # want to shuffle the same way every epoch before split, we are assuming
                 # that the user will call set_seed
@@ -1582,7 +1586,12 @@ class BucketBatchByLengthDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        return None
+        if self.dataset_size is None:
+            num_rows = 0
+            for _ in self.create_dict_iterator():
+                num_rows += 1
+            self.dataset_size = num_rows
+        return self.dataset_size
 
 
 class BatchDataset(DatasetOp):
@@ -1643,12 +1652,14 @@ class BatchDataset(DatasetOp):
         Return:
             Number, number of batches.
""" - child_size = self.children[0].get_dataset_size() - if child_size is not None and isinstance(self.batch_size, int): - if self.drop_remainder: - return math.floor(child_size / self.batch_size) - return math.ceil(child_size / self.batch_size) - return None + if self.dataset_size is None: + child_size = self.children[0].get_dataset_size() + if child_size is not None and isinstance(self.batch_size, int): + if self.drop_remainder: + self.dataset_size = math.floor(child_size / self.batch_size) + else: + self.dataset_size = math.ceil(child_size / self.batch_size) + return self.dataset_size def get_batch_size(self): """ @@ -2000,7 +2011,9 @@ class MapDataset(DatasetOp): Return: Number, number of batches. """ - return self.children[0].get_dataset_size() + if self.dataset_size is None: + self.dataset_size = self.children[0].get_dataset_size() + return self.dataset_size def __deepcopy__(self, memodict): if id(self) in memodict: @@ -2019,6 +2032,7 @@ class MapDataset(DatasetOp): new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) new_op.cache = copy.deepcopy(self.cache, memodict) new_op.operations = self.operations + new_op.dataset_size = self.dataset_size return new_op # Iterator bootstrap will be called on iterator construction. @@ -2091,11 +2105,16 @@ class FilterDataset(DatasetOp): def get_dataset_size(self): """ Get the number of batches in an epoch. - the size cannot be determined before we run the pipeline. + Return: - 0 + Number, num of batches. """ - return 0 + if self.dataset_size is None: + num_rows = 0 + for _ in self.create_dict_iterator(): + num_rows += 1 + self.dataset_size = num_rows + return self.dataset_size class RepeatDataset(DatasetOp): @@ -2129,10 +2148,11 @@ class RepeatDataset(DatasetOp): Return: Number, number of batches. """ - child_size = self.children[0].get_dataset_size() - if child_size is not None: - return child_size * self.count - return None + if self.dataset_size is None: + child_size = self.children[0].get_dataset_size() + if child_size is not None: + self.dataset_size = child_size * self.count + return self.dataset_size def get_repeat_count(self): """ @@ -2172,11 +2192,12 @@ class SkipDataset(DatasetOp): Return: Number, number of batches. """ - child_size = self.children[0].get_dataset_size() - output_size = 0 - if self.count >= 0 and self.count < child_size: - output_size = child_size - self.count - return output_size + if self.dataset_size is None: + child_size = self.children[0].get_dataset_size() + self.dataset_size = 0 + if self.count >= 0 and self.count < child_size: + self.dataset_size = child_size - self.count + return self.dataset_size class TakeDataset(DatasetOp): @@ -2207,10 +2228,13 @@ class TakeDataset(DatasetOp): Return: Number, number of batches. """ - child_size = self.children[0].get_dataset_size() - if child_size < self.count: - return child_size - return self.count + if self.dataset_size is None: + child_size = self.children[0].get_dataset_size() + if child_size < self.count: + self.dataset_size = child_size + else: + self.dataset_size = self.count + return self.dataset_size class ZipDataset(DatasetOp): @@ -2241,10 +2265,11 @@ class ZipDataset(DatasetOp): Return: Number, number of batches. 
""" - children_sizes = [c.get_dataset_size() for c in self.children] - if all(c is not None for c in children_sizes): - return min(children_sizes) - return None + if self.dataset_size is None: + children_sizes = [c.get_dataset_size() for c in self.children] + if all(c is not None for c in children_sizes): + self.dataset_size = min(children_sizes) + return self.dataset_size def num_classes(self): """ @@ -2291,9 +2316,10 @@ class ConcatDataset(DatasetOp): Return: Number, number of batches. """ - children_sizes = [c.get_dataset_size() for c in self.children] - dataset_size = sum(children_sizes) - return dataset_size + if self.dataset_size is None: + children_sizes = [c.get_dataset_size() for c in self.children] + self.dataset_size = sum(children_sizes) + return self.dataset_size class RenameDataset(DatasetOp): @@ -2439,6 +2465,11 @@ class RangeDataset(MappableDataset): def is_sharded(self): return False + def get_dataset_size(self): + if self.dataset_size is None: + self.dataset_size = math.ceil((self.stop - self.start)/self.step) + return self.dataset_size + def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False): """ @@ -2617,14 +2648,13 @@ class ImageFolderDatasetV2(MappableDataset): Return: Number, number of batches. """ - num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0] - rows_per_shard = get_num_rows(num_rows, self.num_shards) - rows_from_sampler = self._get_sampler_dataset_size() - - if rows_from_sampler is None: - return rows_per_shard - - return min(rows_from_sampler, rows_per_shard) + if self.dataset_size is None: + num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0] + self.dataset_size = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() + if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: + self.dataset_size = rows_from_sampler + return self.dataset_size def num_classes(self): """ @@ -2758,14 +2788,13 @@ class MnistDataset(MappableDataset): Return: Number, number of batches. """ - num_rows = MnistOp.get_num_rows(self.dataset_dir) - rows_per_shard = get_num_rows(num_rows, self.num_shards) - rows_from_sampler = self._get_sampler_dataset_size() - - if rows_from_sampler is None: - return rows_per_shard - - return min(rows_from_sampler, rows_per_shard) + if self.dataset_size is None: + num_rows = MnistOp.get_num_rows(self.dataset_dir) + self.dataset_size = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() + if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: + self.dataset_size = rows_from_sampler + return self.dataset_size def is_shuffled(self): if self.shuffle_level is None: @@ -2868,20 +2897,20 @@ class MindDataset(MappableDataset): Return: Number, number of batches. """ - if self._dataset_size is None: + if self.dataset_size is None: if self.load_dataset: dataset_file = [self.dataset_file] else: dataset_file = self.dataset_file num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded) - return num_rows - return self._dataset_size + self.dataset_size = num_rows + return self.dataset_size # manually set dataset_size as a tempoary solution. def set_dataset_size(self, value): logger.warning("WARN_DEPRECATED: This method is deprecated. 
         if value >= 0:
-            self._dataset_size = value
+            self.dataset_size = value
         else:
             raise ValueError('Set dataset_size with negative value {}'.format(value))
 
@@ -3205,6 +3234,7 @@ class GeneratorDataset(MappableDataset):
         self.source = source
         self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
+        self.num_shards = num_shards
 
         if column_names is not None and not isinstance(column_names, list):
             column_names = [column_names]
@@ -3225,9 +3255,6 @@ class GeneratorDataset(MappableDataset):
                 self.column_names.append(col["name"])
                 self.column_types.append(DataType(col["type"]))
 
-        if source is not None and hasattr(source, "__len__"):
-            self._dataset_size = len(source)
-
     def get_args(self):
         args = super().get_args()
         args["source"] = self.source
@@ -3242,19 +3269,27 @@ class GeneratorDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return self._dataset_size
-
-        if self._dataset_size is None:
-            return None
-
-        return min(rows_from_sampler, self._dataset_size)
+        if self.dataset_size is None:
+            if hasattr(self.source, "__len__"):
+                if not self.num_shards:
+                    self.dataset_size = len(self.source)
+                else:
+                    self.dataset_size = math.ceil(len(self.source)/self.num_shards)
+
+                rows_from_sampler = self._get_sampler_dataset_size()
+                if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                    self.dataset_size = rows_from_sampler
+            else:
+                num_rows = 0
+                for _ in self.create_dict_iterator():
+                    num_rows += 1
+                self.dataset_size = num_rows
+        return self.dataset_size
 
     # manually set dataset_size as a temporary solution.
     def set_dataset_size(self, value):
         if value >= 0:
-            self._dataset_size = value
+            self.dataset_size = value
         else:
             raise ValueError('Set dataset_size with negative value {}'.format(value))
@@ -3271,6 +3306,7 @@ class GeneratorDataset(MappableDataset):
         new_op.column_types = copy.deepcopy(self.column_types, memodict)
         new_op.column_names = copy.deepcopy(self.column_names, memodict)
         new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
+        new_op.dataset_size = self.dataset_size
         new_op.sampler = copy.deepcopy(self.sampler)
 
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
@@ -3433,19 +3469,18 @@ class TFRecordDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
-            num_rows = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples is None:
-                return num_rows
-            return min(self.num_samples, num_rows)
-        return self._dataset_size
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples is not None and self.num_samples < self.dataset_size:
+                self.dataset_size = self.num_samples
+        return self.dataset_size
 
     # manually set dataset_size as a tempoary solution.
     def set_dataset_size(self, value):
         logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
         if value >= 0:
-            self._dataset_size = value
+            self.dataset_size = value
         else:
             raise ValueError('Set dataset_size with negative value {}'.format(value))
 
@@ -3574,19 +3609,19 @@ class ManifestDataset(MappableDataset):
         Return:
             Number, number of batches.
""" - if self.class_indexing is None: - class_indexing = dict() - else: - class_indexing = self.class_indexing - - num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0] - rows_per_shard = get_num_rows(num_rows, self.num_shards) - rows_from_sampler = self._get_sampler_dataset_size() + if self.dataset_size is None: + if self.class_indexing is None: + class_indexing = dict() + else: + class_indexing = self.class_indexing - if rows_from_sampler is None: - return rows_per_shard + num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0] + self.dataset_size = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() - return min(rows_from_sampler, rows_per_shard) + if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: + self.dataset_size = rows_from_sampler + return self.dataset_size def num_classes(self): """ @@ -3742,15 +3777,15 @@ class Cifar10Dataset(MappableDataset): Return: Number, number of batches. """ + if self.dataset_size is None: + num_rows = CifarOp.get_num_rows(self.dataset_dir, True) + self.dataset_size = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() - num_rows = CifarOp.get_num_rows(self.dataset_dir, True) - rows_per_shard = get_num_rows(num_rows, self.num_shards) - rows_from_sampler = self._get_sampler_dataset_size() - - if rows_from_sampler is None: - return rows_per_shard + if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: + self.dataset_size = rows_from_sampler - return min(rows_from_sampler, rows_per_shard) + return self.dataset_size def is_shuffled(self): if self.shuffle_level is None: @@ -3878,15 +3913,15 @@ class Cifar100Dataset(MappableDataset): Return: Number, number of batches. """ + if self.dataset_size is None: + num_rows = CifarOp.get_num_rows(self.dataset_dir, False) + self.dataset_size = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() - num_rows = CifarOp.get_num_rows(self.dataset_dir, False) - rows_per_shard = get_num_rows(num_rows, self.num_shards) - rows_from_sampler = self._get_sampler_dataset_size() - - if rows_from_sampler is None: - return rows_per_shard + if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: + self.dataset_size = rows_from_sampler - return min(rows_from_sampler, rows_per_shard) + return self.dataset_size def is_shuffled(self): if self.shuffle_level is None: @@ -3971,16 +4006,16 @@ class RandomDataset(SourceDataset): Return: Number, number of batches. """ + if self.dataset_size is None: + num_rows = CifarOp.get_num_rows(self.dataset_dir, True) - num_rows = CifarOp.get_num_rows(self.dataset_dir, True) - - rows_per_shard = get_num_rows(num_rows, self.num_shards) - rows_from_sampler = self._get_sampler_dataset_size() + self.dataset_size = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() - if rows_from_sampler is None: - return rows_per_shard + if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: + self.dataset_size = rows_from_sampler - return min(rows_from_sampler, rows_per_shard) + return self.dataset_size def is_shuffled(self): if self.shuffle_level is None: @@ -4317,24 +4352,25 @@ class VOCDataset(MappableDataset): Return: Number, number of batches. 
""" - if self.num_samples is None: - num_samples = 0 - else: - num_samples = self.num_samples + if self.dataset_size is None: + if self.num_samples is None: + num_samples = 0 + else: + num_samples = self.num_samples - if self.class_indexing is None: - class_indexing = dict() - else: - class_indexing = self.class_indexing + if self.class_indexing is None: + class_indexing = dict() + else: + class_indexing = self.class_indexing - num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples) - rows_per_shard = get_num_rows(num_rows, self.num_shards) - rows_from_sampler = self._get_sampler_dataset_size() + num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples) + self.dataset_size = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() - if rows_from_sampler is None: - return rows_per_shard + if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: + self.dataset_size = rows_from_sampler - return min(rows_from_sampler, rows_per_shard) + return self.dataset_size def get_class_indexing(self): """ @@ -4514,14 +4550,15 @@ class CocoDataset(MappableDataset): Return: Number, number of batches. """ - num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task) - rows_per_shard = get_num_rows(num_rows, self.num_shards) - rows_from_sampler = self._get_sampler_dataset_size() + if self.dataset_size is None: + num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task) + self.dataset_size = get_num_rows(num_rows, self.num_shards) + rows_from_sampler = self._get_sampler_dataset_size() - if rows_from_sampler is None: - return rows_per_shard + if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: + self.dataset_size = rows_from_sampler - return min(rows_from_sampler, rows_per_shard) + return self.dataset_size def get_class_indexing(self): """ @@ -4638,7 +4675,7 @@ class CelebADataset(MappableDataset): Return: Number, number of batches. """ - if self._dataset_size is None: + if self.dataset_size is None: dir = os.path.realpath(self.dataset_dir) attr_file = os.path.join(dir, "list_attr_celeba.txt") num_rows = '' @@ -4649,14 +4686,13 @@ class CelebADataset(MappableDataset): raise RuntimeError("attr_file not found.") except BaseException: raise RuntimeError("Get dataset size failed from attribution file.") - rows_per_shard = get_num_rows(num_rows, self.num_shards) - if self.num_samples is not None: - rows_per_shard = min(self.num_samples, rows_per_shard) + self.dataset_size = get_num_rows(num_rows, self.num_shards) + if self.num_samples is not None and self.num_samples < self.dataset_size: + self.dataset_size = self.num_samples rows_from_sampler = self._get_sampler_dataset_size() - if rows_from_sampler is None: - return rows_per_shard - return min(rows_from_sampler, rows_per_shard) - return self._dataset_size + if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: + self.dataset_size = rows_from_sampler + return self.dataset_size def is_shuffled(self): if self.shuffle_level is None: @@ -4888,13 +4924,12 @@ class CLUEDataset(SourceDataset): Return: Number, number of batches. 
""" - if self._dataset_size is None: + if self.dataset_size is None: num_rows = ClueOp.get_num_rows(self.dataset_files) - num_rows = get_num_rows(num_rows, self.num_shards) - if self.num_samples is None: - return num_rows - return min(self.num_samples, num_rows) - return self._dataset_size + self.dataset_size = get_num_rows(num_rows, self.num_shards) + if self.num_samples is not None and self.num_samples < self.dataset_size: + self.dataset_size = self.num_samples + return self.dataset_size def is_shuffled(self): return self.shuffle_files @@ -4991,13 +5026,12 @@ class CSVDataset(SourceDataset): Return: Number, number of batches. """ - if self._dataset_size is None: + if self.dataset_size is None: num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None) - num_rows = get_num_rows(num_rows, self.num_shards) - if self.num_samples == -1: - return num_rows - return min(self.num_samples, num_rows) - return self._dataset_size + self.dataset_size = get_num_rows(num_rows, self.num_shards) + if self.num_samples != -1 and self.num_samples < self.dataset_size: + self.dataset_size = num_rows + return self.dataset_size def is_shuffled(self): return self.shuffle_files @@ -5082,15 +5116,14 @@ class TextFileDataset(SourceDataset): Return: Number, number of batches. """ - if self._dataset_size is None: + if self.dataset_size is None: num_rows = TextFileOp.get_num_rows(self.dataset_files) - num_rows = get_num_rows(num_rows, self.num_shards) + self.dataset_size = get_num_rows(num_rows, self.num_shards) # If the user gave a num samples in the dataset, then the sampler will limit the rows returned # to that amount. Account for that here in the row count if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples: - num_rows = self.num_samples - return num_rows - return self._dataset_size + self.dataset_size = self.num_samples + return self.dataset_size def is_shuffled(self): return self.shuffle_files @@ -5308,6 +5341,7 @@ class BuildVocabDataset(DatasetOp): new_op.vocab = self.vocab new_op.special_tokens = copy.deepcopy(self.special_tokens) new_op.special_first = copy.deepcopy(self.special_first) + new_op.dataset_size = self.dataset_size return new_op @@ -5365,4 +5399,5 @@ class BuildSentencePieceVocabDataset(DatasetOp): new_op.params = copy.deepcopy(self.params, memodict) new_op.vocab = self.vocab new_op.model_type = copy.deepcopy(self.model_type) + new_op.dataset_size = self.dataset_size return new_op diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index e2d810b29a..5755b79207 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -72,6 +72,7 @@ class Iterator: ITERATORS_LIST.append(weakref.ref(self)) # create a copy of tree and work on it. self.dataset = copy.deepcopy(dataset) + self.ori_dataset = dataset self.parent_subtree = [] # The dataset passed into the iterator is not the root of the tree. @@ -247,6 +248,8 @@ class Iterator: if not data: if self._index == 0: logger.warning("No records available.") + if self.ori_dataset.dataset_size is None: + self.ori_dataset.dataset_size = self._index raise StopIteration self._index += 1 return data diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py index 0f3c8d145b..d00387a36e 100644 --- a/mindspore/dataset/transforms/c_transforms.py +++ b/mindspore/dataset/transforms/c_transforms.py @@ -31,7 +31,8 @@ class OneHot(cde.OneHotOp): Tensor operation to apply one hot encoding. 
 
     Args:
-        num_classes (int): Number of classes of the label, it should be bigger than feature size.
+        num_classes (int): Number of classes of the label. It should be larger
+            than or equal to the number of label classes.
 
     Raises:
         RuntimeError: feature size is bigger than num_classes.
diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py
index 405b874110..fb0d1bc25e 100644
--- a/tests/ut/python/dataset/test_bucket_batch_by_length.py
+++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py
@@ -382,6 +382,25 @@ def test_bucket_batch_multi_column():
     assert same_shape_output == same_shape_expected_output
     assert variable_shape_output == variable_shape_expected_output
 
+
+def test_bucket_batch_get_dataset_size():
+    dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
+
+    column_names = ["col1"]
+    bucket_boundaries = [1, 2, 3]
+    bucket_batch_sizes = [3, 3, 2, 2]
+    element_length_function = (lambda x: x[0] % 4)
+
+    dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
+                                             bucket_batch_sizes, element_length_function)
+
+    data_size = dataset.get_dataset_size()
+
+    num_rows = 0
+    for _ in dataset.create_dict_iterator():
+        num_rows += 1
+
+    assert data_size == num_rows
 
 if __name__ == '__main__':
     test_bucket_batch_invalid_input()
@@ -394,3 +413,4 @@
     test_bucket_batch_drop_remainder()
     test_bucket_batch_default_length_function()
     test_bucket_batch_multi_column()
+    test_bucket_batch_get_dataset_size()
diff --git a/tests/ut/python/dataset/test_datasets_generator.py b/tests/ut/python/dataset/test_datasets_generator.py
index 5f3bc998f3..512f746b89 100644
--- a/tests/ut/python/dataset/test_datasets_generator.py
+++ b/tests/ut/python/dataset/test_datasets_generator.py
@@ -25,6 +25,16 @@ def generator_1d():
     for i in range(64):
         yield (np.array([i]),)
 
+
+class DatasetGenerator:
+    def __init__(self):
+        pass
+
+    def __getitem__(self, item):
+        return (np.array([item]),)
+
+    def __len__(self):
+        return 10
+
 
 def test_generator_0():
     """
@@ -615,6 +625,103 @@ def test_generator_schema():
         type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
 
 
+def test_generator_dataset_size_0():
+    """
+    Test GeneratorDataset get_dataset_size by iterator method.
+    """
+    logger.info("Test 1D Generator : 0 - 63 get_dataset_size")
+
+    data1 = ds.GeneratorDataset(generator_1d, ["data"])
+    data_size = data1.get_dataset_size()
+
+    num_rows = 0
+    for _ in data1.create_dict_iterator():  # each data is a dictionary
+        num_rows = num_rows + 1
+    assert data_size == num_rows
+
+
+def test_generator_dataset_size_1():
+    """
+    Test GeneratorDataset get_dataset_size by __len__ method.
+ """ + logger.info("Test DatasetGenerator get_dataset_size") + + dataset_generator = DatasetGenerator() + data1 = ds.GeneratorDataset(dataset_generator, ["data"]) + + data_size = data1.get_dataset_size() + + num_rows = 0 + for _ in data1.create_dict_iterator(): + num_rows = num_rows + 1 + assert data_size == num_rows + + +def test_generator_dataset_size_2(): + """ + Test GeneratorDataset + repeat get_dataset_size + """ + logger.info("Test 1D Generator + repeat get_dataset_size") + + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.repeat(2) + + data_size = data1.get_dataset_size() + + num_rows = 0 + for _ in data1.create_dict_iterator(): + num_rows = num_rows + 1 + assert data_size == num_rows + + +def test_generator_dataset_size_3(): + """ + Test GeneratorDataset + batch get_dataset_size + """ + logger.info("Test 1D Generator + batch get_dataset_size") + + data1 = ds.GeneratorDataset(generator_1d, ["data"]) + data1 = data1.batch(4) + + data_size = data1.get_dataset_size() + + num_rows = 0 + for _ in data1.create_dict_iterator(): + num_rows += 1 + assert data_size == num_rows + + +def test_generator_dataset_size_4(): + """ + Test GeneratorDataset + num_shards + """ + logger.info("Test 1D Generator : 0 - 63 + num_shards get_dataset_size") + + dataset_generator = DatasetGenerator() + data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0) + data_size = data1.get_dataset_size() + + num_rows = 0 + for _ in data1.create_dict_iterator(): # each data is a dictionary + num_rows = num_rows + 1 + assert data_size == num_rows + +def test_generator_dataset_size_5(): + """ + Test get_dataset_size after create_dict_iterator + """ + logger.info("Test get_dataset_size after create_dict_iterator") + + dataset_generator = DatasetGenerator() + data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0) + + num_rows = 0 + for _ in data1.create_dict_iterator(): # each data is a dictionary + num_rows = num_rows + 1 + data_size = data1.get_dataset_size() + assert data_size == num_rows + + def manual_test_generator_keyboard_interrupt(): """ Test keyboard_interrupt @@ -663,3 +770,9 @@ if __name__ == "__main__": test_generator_num_samples() test_generator_num_samples_underflow() test_generator_schema() + test_generator_dataset_size_0() + test_generator_dataset_size_1() + test_generator_dataset_size_2() + test_generator_dataset_size_3() + test_generator_dataset_size_4() + test_generator_dataset_size_5() diff --git a/tests/ut/python/dataset/test_filterop.py b/tests/ut/python/dataset/test_filterop.py index 876278571d..1c99b7864b 100644 --- a/tests/ut/python/dataset/test_filterop.py +++ b/tests/ut/python/dataset/test_filterop.py @@ -484,6 +484,16 @@ def test_filter_by_generator_with_map_all_sort(): assert ret_data[0]["col1"] == 0 assert ret_data[9]["col6"] == 509 +def test_filter_by_generator_get_dataset_size(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4) + data_sie = dataset.get_dataset_size() + + num_iter = 0 + for _ in dataset.create_dict_iterator(): + num_iter += 1 + assert data_sie == num_iter + if __name__ == '__main__': test_diff_predicate_func() @@ -506,3 +516,4 @@ if __name__ == '__main__': test_filter_by_generator_with_zip() test_filter_by_generator_with_zip_after() test_filter_by_generator_Partial() + test_filter_by_generator_get_dataset_size()