|
|
@@ -143,7 +143,7 @@ class Dataset:
         self._input_indexs = ()
         self._output_types = None
         self._output_shapes = None
-        self._dataset_size = None
+        self.dataset_size = None
         self._batch_size = None
         self._num_classes = None
         self._repeat_count = None
@@ -1189,8 +1189,6 @@ class Dataset:
         device_iter = TupleIterator(self)
         self._output_shapes = device_iter.get_output_shapes()
         self._output_types = device_iter.get_output_types()
-        if self._dataset_size is None:
-            self._dataset_size = device_iter.get_dataset_size()
         self._batch_size = device_iter.get_batch_size()
         self._num_classes = device_iter.num_classes()
         self._repeat_count = device_iter.get_repeat_count()
@@ -1225,9 +1223,10 @@ class Dataset:
         Return:
             Number, number of batches.
         """
-        if self.children:
-            return self.children[0].get_dataset_size()
-        return None
+        if self.dataset_size is None:
+            if self.children:
+                self.dataset_size = self.children[0].get_dataset_size()
+        return self.dataset_size
 
     def num_classes(self):
         """
@@ -1378,6 +1377,8 @@ class MappableDataset(SourceDataset):
     def add_sampler(self, new_sampler):
         # note: by adding a sampler, we mean that the sampled ids will flow to new_sampler
         # after first passing through the current samplers attached to this dataset.
+        if self.dataset_size is not None:
+            self.dataset_size = None
         new_sampler.add_child(self.sampler)
         self.sampler = new_sampler
 
@@ -1406,6 +1407,8 @@ class MappableDataset(SourceDataset): |
|
|
raise TypeError("Input sampler can not be None.") |
|
|
raise TypeError("Input sampler can not be None.") |
|
|
if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): |
|
|
if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)): |
|
|
raise TypeError("Input sampler is not an instance of a sampler.") |
|
|
raise TypeError("Input sampler is not an instance of a sampler.") |
|
|
|
|
|
if self.dataset_size is not None: |
|
|
|
|
|
self.dataset_size = None |
|
|
|
|
|
|
|
|
self.sampler = self.sampler.child_sampler |
|
|
self.sampler = self.sampler.child_sampler |
|
|
self.add_sampler(new_sampler) |
|
|
self.add_sampler(new_sampler) |
|
|
@@ -1505,6 +1508,7 @@ class MappableDataset(SourceDataset):
         current_split_start_index = 0
         for size in absolute_sizes:
             ds = copy.deepcopy(self)
+            ds.dataset_size = None
             if randomize:
                 # want to shuffle the same way every epoch before split, we are assuming
                 # that the user will call set_seed
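Note: the three MappableDataset hunks above enforce a single invariant: any mutation that can change the number of rows (attaching a sampler in add_sampler, replacing one in use_sampler, or copying the dataset in split) resets the dataset_size cache to None, so the next get_dataset_size call recomputes it. A minimal sketch of the memoize-and-invalidate idiom this patch applies throughout, illustrative only (the _compute_size helper is hypothetical and stands in for the per-dataset row counting):

    class CachedSizeSketch:
        """Illustrative only: memoize an expensive row count."""

        def __init__(self):
            self.dataset_size = None  # None means "not computed yet"

        def _compute_size(self):
            # Hypothetical stand-in for the real row-counting backends.
            return 42

        def get_dataset_size(self):
            # First call computes and caches; later calls hit the cache.
            if self.dataset_size is None:
                self.dataset_size = self._compute_size()
            return self.dataset_size

        def add_sampler(self, new_sampler):
            # The sampler can change how many rows come out, so the cached
            # value must be dropped before the change takes effect.
            self.dataset_size = None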
|
|
@@ -1582,7 +1586,12 @@ class BucketBatchByLengthDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        return None
+        if self.dataset_size is None:
+            num_rows = 0
+            for _ in self.create_dict_iterator():
+                num_rows += 1
+            self.dataset_size = num_rows
+        return self.dataset_size
 
 
 class BatchDataset(DatasetOp):
@@ -1643,12 +1652,14 @@ class BatchDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.children[0].get_dataset_size()
-        if child_size is not None and isinstance(self.batch_size, int):
-            if self.drop_remainder:
-                return math.floor(child_size / self.batch_size)
-            return math.ceil(child_size / self.batch_size)
-        return None
+        if self.dataset_size is None:
+            child_size = self.children[0].get_dataset_size()
+            if child_size is not None and isinstance(self.batch_size, int):
+                if self.drop_remainder:
+                    self.dataset_size = math.floor(child_size / self.batch_size)
+                else:
+                    self.dataset_size = math.ceil(child_size / self.batch_size)
+        return self.dataset_size
 
     def get_batch_size(self):
         """
@@ -2000,7 +2011,9 @@ class MapDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        return self.children[0].get_dataset_size()
+        if self.dataset_size is None:
+            self.dataset_size = self.children[0].get_dataset_size()
+        return self.dataset_size
 
     def __deepcopy__(self, memodict):
         if id(self) in memodict:
@@ -2019,6 +2032,7 @@ class MapDataset(DatasetOp):
         new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
         new_op.cache = copy.deepcopy(self.cache, memodict)
         new_op.operations = self.operations
+        new_op.dataset_size = self.dataset_size
         return new_op
 
     # Iterator bootstrap will be called on iterator construction.
@@ -2091,11 +2105,16 @@ class FilterDataset(DatasetOp):
     def get_dataset_size(self):
         """
         Get the number of batches in an epoch.
-        the size cannot be determined before we run the pipeline.
+
         Return:
-            0
+            Number, num of batches.
         """
-        return 0
+        if self.dataset_size is None:
+            num_rows = 0
+            for _ in self.create_dict_iterator():
+                num_rows += 1
+            self.dataset_size = num_rows
+        return self.dataset_size
 
 
 class RepeatDataset(DatasetOp):
@@ -2129,10 +2148,11 @@ class RepeatDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.children[0].get_dataset_size()
-        if child_size is not None:
-            return child_size * self.count
-        return None
+        if self.dataset_size is None:
+            child_size = self.children[0].get_dataset_size()
+            if child_size is not None:
+                self.dataset_size = child_size * self.count
+        return self.dataset_size
 
     def get_repeat_count(self):
         """
@@ -2172,11 +2192,12 @@ class SkipDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.children[0].get_dataset_size()
-        output_size = 0
-        if self.count >= 0 and self.count < child_size:
-            output_size = child_size - self.count
-        return output_size
+        if self.dataset_size is None:
+            child_size = self.children[0].get_dataset_size()
+            self.dataset_size = 0
+            if self.count >= 0 and self.count < child_size:
+                self.dataset_size = child_size - self.count
+        return self.dataset_size
 
 
 class TakeDataset(DatasetOp):
@@ -2207,10 +2228,13 @@ class TakeDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.children[0].get_dataset_size()
-        if child_size < self.count:
-            return child_size
-        return self.count
+        if self.dataset_size is None:
+            child_size = self.children[0].get_dataset_size()
+            if child_size < self.count:
+                self.dataset_size = child_size
+            else:
+                self.dataset_size = self.count
+        return self.dataset_size
 
 
 class ZipDataset(DatasetOp):
@@ -2241,10 +2265,11 @@ class ZipDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        children_sizes = [c.get_dataset_size() for c in self.children]
-        if all(c is not None for c in children_sizes):
-            return min(children_sizes)
-        return None
+        if self.dataset_size is None:
+            children_sizes = [c.get_dataset_size() for c in self.children]
+            if all(c is not None for c in children_sizes):
+                self.dataset_size = min(children_sizes)
+        return self.dataset_size
 
     def num_classes(self):
         """
@@ -2291,9 +2316,10 @@ class ConcatDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        children_sizes = [c.get_dataset_size() for c in self.children]
-        dataset_size = sum(children_sizes)
-        return dataset_size
+        if self.dataset_size is None:
+            children_sizes = [c.get_dataset_size() for c in self.children]
+            self.dataset_size = sum(children_sizes)
+        return self.dataset_size
 
 
 class RenameDataset(DatasetOp):
@@ -2439,6 +2465,11 @@ class RangeDataset(MappableDataset):
     def is_sharded(self):
         return False
 
+    def get_dataset_size(self):
+        if self.dataset_size is None:
+            self.dataset_size = math.ceil((self.stop - self.start)/self.step)
+        return self.dataset_size
+
 
 def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
     """
@@ -2617,14 +2648,13 @@ class ImageFolderDatasetV2(MappableDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+        if self.dataset_size is None:
+            num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+        return self.dataset_size
 
     def num_classes(self):
         """
@@ -2758,14 +2788,13 @@ class MnistDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = MnistOp.get_num_rows(self.dataset_dir)
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+        if self.dataset_size is None:
+            num_rows = MnistOp.get_num_rows(self.dataset_dir)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+        return self.dataset_size
 
     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -2868,20 +2897,20 @@ class MindDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             if self.load_dataset:
                 dataset_file = [self.dataset_file]
             else:
                 dataset_file = self.dataset_file
             num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
-            return num_rows
-        return self._dataset_size
+            self.dataset_size = num_rows
+        return self.dataset_size
 
     # manually set dataset_size as a tempoary solution.
     def set_dataset_size(self, value):
         logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
         if value >= 0:
-            self._dataset_size = value
+            self.dataset_size = value
         else:
             raise ValueError('Set dataset_size with negative value {}'.format(value))
 
@@ -3205,6 +3234,7 @@ class GeneratorDataset(MappableDataset):
         self.source = source
         self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
+        self.num_shards = num_shards
 
         if column_names is not None and not isinstance(column_names, list):
             column_names = [column_names]
@@ -3225,9 +3255,6 @@ class GeneratorDataset(MappableDataset):
                 self.column_names.append(col["name"])
                 self.column_types.append(DataType(col["type"]))
 
-        if source is not None and hasattr(source, "__len__"):
-            self._dataset_size = len(source)
-
     def get_args(self):
         args = super().get_args()
         args["source"] = self.source
@@ -3242,19 +3269,27 @@ class GeneratorDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return self._dataset_size
-
-        if self._dataset_size is None:
-            return None
-
-        return min(rows_from_sampler, self._dataset_size)
+        if self.dataset_size is None:
+            if hasattr(self.source, "__len__"):
+                if not self.num_shards:
+                    self.dataset_size = len(self.source)
+                else:
+                    self.dataset_size = math.ceil(len(self.source)/self.num_shards)
+
+                rows_from_sampler = self._get_sampler_dataset_size()
+                if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                    self.dataset_size = rows_from_sampler
+            else:
+                num_rows = 0
+                for _ in self.create_dict_iterator():
+                    num_rows += 1
+                self.dataset_size = num_rows
+
+        return self.dataset_size
 
     # manually set dataset_size as a temporary solution.
     def set_dataset_size(self, value):
         if value >= 0:
-            self._dataset_size = value
+            self.dataset_size = value
         else:
             raise ValueError('Set dataset_size with negative value {}'.format(value))
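Note: in the GeneratorDataset hunk above, a source that defines __len__ gets its cached size per shard as math.ceil(len(self.source)/self.num_shards); for example, a 10-element source over num_shards == 4 yields math.ceil(10/4) == 3 rows per shard. A source without __len__ falls back to a full pass over create_dict_iterator() to count rows, which is exactly the expensive path the dataset_size cache is meant to avoid repeating.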
|
|
|
|
|
|
|
|
@@ -3271,6 +3306,7 @@ class GeneratorDataset(MappableDataset):
         new_op.column_types = copy.deepcopy(self.column_types, memodict)
         new_op.column_names = copy.deepcopy(self.column_names, memodict)
         new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
+        new_op.dataset_size = self.dataset_size
 
         new_op.sampler = copy.deepcopy(self.sampler)
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
@@ -3433,19 +3469,18 @@ class TFRecordDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
-            num_rows = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples is None:
-                return num_rows
-            return min(self.num_samples, num_rows)
-        return self._dataset_size
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples is not None and self.num_samples < self.dataset_size:
+                self.dataset_size = self.num_samples
+        return self.dataset_size
 
     # manually set dataset_size as a tempoary solution.
     def set_dataset_size(self, value):
         logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
         if value >= 0:
-            self._dataset_size = value
+            self.dataset_size = value
         else:
             raise ValueError('Set dataset_size with negative value {}'.format(value))
 
@@ -3574,19 +3609,19 @@ class ManifestDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self.class_indexing is None:
-            class_indexing = dict()
-        else:
-            class_indexing = self.class_indexing
-
-        num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+        if self.dataset_size is None:
+            if self.class_indexing is None:
+                class_indexing = dict()
+            else:
+                class_indexing = self.class_indexing
+
+            num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()
+
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+        return self.dataset_size
 
     def num_classes(self):
         """
@@ -3742,15 +3777,15 @@ class Cifar10Dataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
-
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+        if self.dataset_size is None:
+            num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()
+
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+
+        return self.dataset_size
 
     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -3878,15 +3913,15 @@ class Cifar100Dataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
-
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+        if self.dataset_size is None:
+            num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()
+
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+
+        return self.dataset_size
 
     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -3971,16 +4006,16 @@ class RandomDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
-
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+        if self.dataset_size is None:
+            num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
+
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()
+
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+
+        return self.dataset_size
 
     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -4317,24 +4352,25 @@ class VOCDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
-
-        if self.class_indexing is None:
-            class_indexing = dict()
-        else:
-            class_indexing = self.class_indexing
-
-        num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+        if self.dataset_size is None:
+            if self.num_samples is None:
+                num_samples = 0
+            else:
+                num_samples = self.num_samples
+
+            if self.class_indexing is None:
+                class_indexing = dict()
+            else:
+                class_indexing = self.class_indexing
+
+            num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()
+
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+
+        return self.dataset_size
 
     def get_class_indexing(self):
         """
@@ -4514,14 +4550,15 @@ class CocoDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task)
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+        if self.dataset_size is None:
+            num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()
+
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+
+        return self.dataset_size
 
     def get_class_indexing(self):
         """
@@ -4638,7 +4675,7 @@ class CelebADataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             dir = os.path.realpath(self.dataset_dir)
             attr_file = os.path.join(dir, "list_attr_celeba.txt")
             num_rows = ''
@@ -4649,14 +4686,13 @@ class CelebADataset(MappableDataset):
                 raise RuntimeError("attr_file not found.")
             except BaseException:
                 raise RuntimeError("Get dataset size failed from attribution file.")
-            rows_per_shard = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples is not None:
-                rows_per_shard = min(self.num_samples, rows_per_shard)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples is not None and self.num_samples < self.dataset_size:
+                self.dataset_size = self.num_samples
             rows_from_sampler = self._get_sampler_dataset_size()
-            if rows_from_sampler is None:
-                return rows_per_shard
-            return min(rows_from_sampler, rows_per_shard)
-        return self._dataset_size
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+        return self.dataset_size
 
     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -4888,13 +4924,12 @@ class CLUEDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             num_rows = ClueOp.get_num_rows(self.dataset_files)
-            num_rows = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples is None:
-                return num_rows
-            return min(self.num_samples, num_rows)
-        return self._dataset_size
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples is not None and self.num_samples < self.dataset_size:
+                self.dataset_size = self.num_samples
+        return self.dataset_size
 
     def is_shuffled(self):
         return self.shuffle_files
@@ -4991,13 +5026,12 @@ class CSVDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
-            num_rows = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples == -1:
-                return num_rows
-            return min(self.num_samples, num_rows)
-        return self._dataset_size
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples != -1 and self.num_samples < self.dataset_size:
+                self.dataset_size = self.num_samples
+        return self.dataset_size
 
     def is_shuffled(self):
         return self.shuffle_files
@@ -5082,15 +5116,14 @@ class TextFileDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             num_rows = TextFileOp.get_num_rows(self.dataset_files)
-            num_rows = get_num_rows(num_rows, self.num_shards)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
             # If the user gave a num samples in the dataset, then the sampler will limit the rows returned
             # to that amount. Account for that here in the row count
             if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples:
-                num_rows = self.num_samples
-            return num_rows
-        return self._dataset_size
+                self.dataset_size = self.num_samples
+        return self.dataset_size
 
     def is_shuffled(self):
         return self.shuffle_files
@@ -5308,6 +5341,7 @@ class BuildVocabDataset(DatasetOp):
         new_op.vocab = self.vocab
         new_op.special_tokens = copy.deepcopy(self.special_tokens)
         new_op.special_first = copy.deepcopy(self.special_first)
+        new_op.dataset_size = self.dataset_size
 
         return new_op
 
@@ -5365,4 +5399,5 @@ class BuildSentencePieceVocabDataset(DatasetOp):
         new_op.params = copy.deepcopy(self.params, memodict)
         new_op.vocab = self.vocab
         new_op.model_type = copy.deepcopy(self.model_type)
+        new_op.dataset_size = self.dataset_size
         return new_op
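Note: the __deepcopy__ hunks (MapDataset, GeneratorDataset, BuildVocabDataset, BuildSentencePieceVocabDataset) deliberately carry the cached size into the copy, since a duplicated pipeline produces the same number of rows; split is the one exception and resets ds.dataset_size = None because each piece has a different row count. A minimal sketch of that copy behaviour, illustrative only:

    import copy

    class DeepcopySketch:
        """Illustrative only: keep the cached size across a deepcopy."""

        def __init__(self):
            self.dataset_size = None

        def __deepcopy__(self, memodict):
            new_op = DeepcopySketch()
            # The copy yields the same rows, so the cache stays valid.
            new_op.dataset_size = self.dataset_size
            return new_op

    original = DeepcopySketch()
    original.dataset_size = 100
    duplicate = copy.deepcopy(original)
    assert duplicate.dataset_size == 100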