|
|
|
@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -2708,7 +2708,7 @@ class ConcatDataset(Dataset):
         self.dataset_size = None

-        self._sampler = _select_sampler(None, sampler, None, None, None)
+        self._sampler = samplers.select_sampler(None, sampler, None, None, None)
         cumulative_samples_nums = 0
         for index, child in enumerate(self.children):
             if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None:
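
Note: this hunk only swaps the helper for its new home in the samplers module; the loop over self.children that follows it accumulates each child's sample count. As a minimal, self-contained sketch (toy names, not MindSpore code) of how such cumulative counts let a concat dataset map a global row index to a (child, local index) pair:

import bisect

children = [list(range(3)), list(range(5)), list(range(2))]  # toy child datasets

cumulative = []
total = 0
for child in children:
    total += len(child)       # stands in for the per-child num_samples
    cumulative.append(total)  # e.g. [3, 8, 10]

def locate(global_index):
    # Find the first child whose cumulative count exceeds the global index.
    child_idx = bisect.bisect_right(cumulative, global_index)
    local_index = global_index - (cumulative[child_idx - 1] if child_idx else 0)
    return child_idx, local_index

print(locate(4))  # (1, 1): the fifth row lives in the second child
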
@@ -2990,65 +2990,6 @@ class RangeDataset(MappableDataset):
         return self.dataset_size


-def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
-    """
-    Create sampler based on user input.
-
-    Args:
-        num_samples (int): Number of samples.
-        input_sampler (Union[Iterable, Sampler]): Sampler from user.
-        shuffle (bool): Shuffle.
-        num_shards (int): Number of shard for sharding.
-        shard_id (int): Shard ID.
-        non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False).
-
-    Returns:
-        Sampler, sampler selected based on user input.
-    """
-    if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]):
-        return None
-
-    if input_sampler is not None:
-        # If the user provided a sampler, then it doesn't matter what the other args are because
-        # we are being asked specifically to use the given sampler.
-        # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all
-        # be None. Consider this example:
-        #     sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle)
-        #     data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1)
-        # In this case, the user has given different sample-related arguments that contradict each other.
-        # To prevent this, only allow the user to manually specify the sampler if those arguments are all None
-        if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
-                                       samplers.RandomSampler, samplers.SubsetRandomSampler,
-                                       samplers.WeightedRandomSampler, samplers.Sampler)) and
-                (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))):
-            raise ValueError(
-                'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},'
-                ' shard_id: {}, shuffle: {}.'.format(num_samples, num_shards, shard_id, shuffle))
-        return input_sampler
-    if shuffle is None:
-        if num_shards is not None:
-            # If shuffle is not specified, sharding enabled, use distributed random sampler
-            shuffle = True
-            return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
-        # If shuffle is not specified, sharding disabled, use random sampler
-        if num_samples is not None:
-            return samplers.RandomSampler(replacement=True, num_samples=num_samples)
-        return samplers.RandomSampler(num_samples=num_samples)
-    if shuffle is True:
-        if num_shards is not None:
-            # If shuffle enabled, sharding enabled, use distributed random sampler
-            return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
-        # If shuffle enabled, sharding disabled, use random sampler
-        if num_samples is not None:
-            return samplers.RandomSampler(replacement=True, num_samples=num_samples)
-        return samplers.RandomSampler(num_samples=num_samples)
-    if num_shards is not None:
-        # If shuffle disabled, sharding enabled, use distributed sequential sampler
-        return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples)
-    # If shuffle disabled, sharding disabled, use sequential sampler
-    return samplers.SequentialSampler(num_samples=num_samples)
-
-
 class ImageFolderDataset(MappableDataset):
     """
     A source dataset that reads images from a tree of directories.
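
The comment block in the deleted function spells out the key rule: an explicitly supplied sampler wins, but only if num_samples, num_shards, shard_id and shuffle are all left unset. The conflict it guards against can be shown directly against the relocated helper; the import path below is an assumption (the patch shows the function landing in the samplers module), while the positional signature (num_samples, input_sampler, shuffle, num_shards, shard_id) is taken from the function above:

import mindspore.dataset as ds
from mindspore.dataset.engine import samplers

sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=True)
try:
    # The sampler already encodes sharding, so passing num_shards/shard_id
    # again contradicts it and the check raises ValueError.
    samplers.select_sampler(None, sampler, None, 4, 1)
except ValueError as err:
    print(err)  # Conflicting arguments during sampler assignments. ...
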
@@ -3144,7 +3085,7 @@ class ImageFolderDataset(MappableDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)

         self.dataset_dir = dataset_dir
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.shuffle_level = shuffle
         self.extensions = replace_none(extensions, [])
@@ -3293,7 +3234,7 @@ class MnistDataset(MappableDataset):

         self.dataset_dir = dataset_dir
         self.usage = replace_none(usage, "all")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.shuffle_level = shuffle
         self.num_shards = num_shards
@@ -3386,7 +3327,7 @@ class MindDataset(MappableDataset):
                                     samplers.SequentialSampler)) is False:
                raise ValueError("The sampler is not supported yet.")

-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
        self.num_samples = num_samples

        self.padded_sample = padded_sample
@@ -3470,27 +3411,6 @@ def _generator_fn(generator, num_samples):
             yield val


-def _py_sampler_fn(sampler, num_samples, dataset):
-    """
-    Generator function wrapper for mappable dataset with Python sampler.
-    """
-    if num_samples is not None:
-        sampler_iter = iter(sampler)
-        for _ in range(num_samples):
-            try:
-                idx = next(sampler_iter)
-            except StopIteration:
-                return
-            val = dataset[idx]
-            # convert output tensors to ndarrays
-            yield tuple([np.array(x, copy=False) for x in val])
-    else:
-        for i in sampler:
-            val = dataset[i]
-            # convert output tensors to ndarrays
-            yield tuple([np.array(x, copy=False) for x in val])
-
-
 def _cpp_sampler_fn(sample_ids, dataset):
     """
     Generator function wrapper for mappable dataset with cpp sampler.
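
The deleted _py_sampler_fn wrapped any user-supplied iterable of indices around a random-access source, capping output at num_samples and converting each row to a tuple of ndarrays. For reference, a self-contained sketch of that same pattern (hypothetical names, not MindSpore API):

import numpy as np

def py_sampler_fn(sampler, num_samples, dataset):
    # Drain the sampler (any iterable of indices) against a source that only
    # needs __getitem__, stopping early once num_samples rows are emitted.
    count = 0
    for idx in sampler:
        if num_samples is not None and count >= num_samples:
            return
        val = dataset[idx]
        # convert output tensors to ndarrays, as the deleted code did
        yield tuple(np.array(x, copy=False) for x in val)
        count += 1

source = {i: (np.arange(i + 1),) for i in range(5)}
print(list(py_sampler_fn([4, 2, 0], 2, source)))  # only the first two indices
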
@@ -3518,31 +3438,6 @@ def _cpp_sampler_fn_mp(sample_ids, sample_fn):
     return sample_fn.process(sample_ids)


-def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
-    """
-    Multiprocessing generator function wrapper for mappable dataset with Python sampler.
-    """
-    indices = _fetch_py_sampler_indices(sampler, num_samples)
-    return sample_fn.process(indices)
-
-
-def _fetch_py_sampler_indices(sampler, num_samples):
-    """
-    Indice fetcher for Python sampler.
-    """
-    if num_samples is not None:
-        sampler_iter = iter(sampler)
-        ret = []
-        for _ in range(num_samples):
-            try:
-                val = next(sampler_iter)
-                ret.append(val)
-            except StopIteration:
-                break
-        return ret
-    return [i for i in sampler]
-
-
 def _fill_worker_indices(workers, indices, idx):
     """
     Worker index queue filler, fill worker index queue in round robin order.
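
Only the docstring of _fill_worker_indices is visible in this hunk's context, so the body below is an assumption based purely on that docstring: deal indices to per-worker queues in round-robin order until one fills up, returning the position to resume from. A runnable sketch:

from queue import Queue

def fill_worker_indices(workers, indices, idx):
    # Hypothetical reconstruction, not the actual MindSpore body.
    num_worker = len(workers)
    while idx < len(indices):
        q = workers[idx % num_worker]
        if q.full():
            break          # caller retries later, starting from returned idx
        q.put(indices[idx])
        idx += 1
    return idx

workers = [Queue(maxsize=2) for _ in range(3)]
resume = fill_worker_indices(workers, list(range(10)), 0)
print(resume, [w.qsize() for w in workers])  # 6 [2, 2, 2]
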
@@ -3865,7 +3760,7 @@ class GeneratorDataset(MappableDataset):
                  python_multiprocessing=True):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.source = source
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
         self.python_multiprocessing = python_multiprocessing
@@ -3912,26 +3807,11 @@ class GeneratorDataset(MappableDataset):
         if hasattr(self, "__total_batch__"):
             new_op.__total_batch__ = self.__total_batch__
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
-            if isinstance(new_op.sampler, samplers.BuiltinSampler):
-                if new_op.num_parallel_workers > 1:
-                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
-                    new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
-                else:
-                    new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
+            if new_op.num_parallel_workers > 1:
+                sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
+                new_op.source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
             else:
-                # the sampler provided is not a built-in sampler, it is a list of sample_ids
-                new_op.sample_ids = new_op.sampler
-                # since list of sample_ids are not passed to c++, we need to find the proper len here
-                new_op.source_len = min(self.source_len, len(new_op.sample_ids)) if self.source_len != -1 else len(
-                    new_op.sample_ids)
-                new_op.source_len = min(self.source_len,
-                                        new_op.num_samples) if new_op.num_samples is not None else new_op.source_len
-                new_op.sampler = None
-                if new_op.num_parallel_workers > 1:
-                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
-                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sample_ids, new_op.num_samples, sample_fn))
-                else:
-                    new_op.source = (lambda: _py_sampler_fn(new_op.sample_ids, new_op.num_samples, self.source))
+                new_op.source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
+            new_op.sample_fn = sample_fn
         else:
             try:
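
After this hunk, the Python-side list-of-sample-ids branch is gone: every copy of a GeneratorDataset funnels through the C++ sampler path, where the engine hands down an array of sample ids and the Python callable just fetches rows. A toy sketch of that single retained path (ToySource and cpp_sampler_fn are illustrative stand-ins, mirroring _cpp_sampler_fn above):

import numpy as np

class ToySource:
    # Any source with __getitem__/__len__ qualifies as mappable.
    def __init__(self, n):
        self._rows = [(np.array([i, i * i]),) for i in range(n)]
    def __getitem__(self, index):
        return self._rows[index]
    def __len__(self):
        return len(self._rows)

def cpp_sampler_fn(sample_ids, dataset):
    # Iterate the ids chosen by the (C++) sampler and emit ndarray tuples.
    for i in sample_ids:
        val = dataset[i]
        yield tuple(np.array(x, copy=False) for x in val)

# e.g. the sampler hands down ids [2, 0, 3]
for row in cpp_sampler_fn([2, 0, 3], ToySource(4)):
    print(row)
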
@@ -4089,13 +3969,6 @@ class TFRecordDataset(SourceDataset):
         self.shuffle_level = shuffle
         self.shuffle_files = True

-        # The TF record dataset does not directly support a sampler. It has provided sampling arguments
-        # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
-        # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
-        sampler_shuffle = self.shuffle_files
-        sampler = None
-        self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id,
-                                       non_mappable=True)
         self.shard_equal_rows = replace_none(shard_equal_rows, False)

     def get_args(self):
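
The deleted comment captures the non-mappable rule: TFRecordDataset takes sampling arguments directly (shuffle, num_samples, num_shards, shard_id), and the internal sampler only comes into play when a cache sits above it in the pipeline. A hedged usage sketch, assuming a cache server has already been started with cache_admin and the file path is a placeholder:

import mindspore.dataset as ds

# session_id is a placeholder; a real id comes from `cache_admin -g`.
some_cache = ds.DatasetCache(session_id=1, size=0)
data = ds.TFRecordDataset(["/path/to/file.tfrecord"],
                          num_samples=100, shuffle=True,
                          cache=some_cache)

Without the cache argument, the same sampling arguments are honored by the leaf op itself rather than by a sampler.
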
@@ -4231,7 +4104,7 @@ class ManifestDataset(MappableDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)

         self.dataset_file = dataset_file
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)

         if class_indexing is not None and not isinstance(class_indexing, dict):
             raise RuntimeError("class_indexing must be a dictionary.")
@@ -4396,7 +4269,7 @@ class Cifar10Dataset(MappableDataset):

         self.dataset_dir = dataset_dir
         self.usage = replace_none(usage, "all")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
         self.shard_id = shard_id
@@ -4535,7 +4408,7 @@ class Cifar100Dataset(MappableDataset):

         self.dataset_dir = dataset_dir
         self.usage = replace_none(usage, "all")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.num_shards = num_shards
         self.shard_id = shard_id
@@ -4607,8 +4480,6 @@ class RandomDataset(SourceDataset):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.schema = schema
         self.columns_list = replace_none(columns_list, [])
-        sampler = None
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True)

         self.num_samples = num_samples
         self.total_rows = total_rows
@@ -4900,7 +4771,7 @@ class VOCDataset(MappableDataset):
         self.task = replace_none(task, "Segmentation")
         self.usage = replace_none(usage, "train")
         self.class_indexing = class_indexing
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.decode = replace_none(decode, False)
         self.shuffle_level = shuffle
@@ -5092,7 +4963,7 @@ class CocoDataset(MappableDataset):
         self.dataset_dir = dataset_dir
         self.annotation_file = annotation_file
         self.task = replace_none(task, "Detection")
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.decode = replace_none(decode, False)
         self.shuffle_level = shuffle
@@ -5224,7 +5095,7 @@ class CelebADataset(MappableDataset):
                  extensions=None, num_samples=None, num_shards=None, shard_id=None, cache=None):
         super().__init__(num_parallel_workers=num_parallel_workers)
         self.dataset_dir = dataset_dir
-        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
+        self.sampler = samplers.select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_parallel_workers = num_parallel_workers
         self.decode = replace_none(decode, False)
         self.extensions = replace_none(extensions, [])
@@ -5596,12 +5467,7 @@ class CSVDataset(SourceDataset):
         self.shuffle_files = True

         self.cache = cache
-        # The CSV dataset does not directly support a sampler. It has provided sampling arguments
-        # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
-        # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
-        sampler = None
-        self.sampler = _select_sampler(num_samples, sampler, self.shuffle_files, num_shards, shard_id,
-                                       non_mappable=True)
+
         self.num_shards = replace_none(num_shards, 1)
         self.shard_id = replace_none(shard_id, 0)
         self.num_samples = replace_none(num_samples, 0)
@@ -5715,13 +5581,6 @@ class TextFileDataset(SourceDataset):
         self.shard_id = replace_none(shard_id, 0)

         self.cache = cache
-        # The text file dataset does not directly support a sampler. It has provided sampling arguments
-        # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in
-        # the pipeline contains a cache. If there is no cache above it, then this sampler is not used.
-        sampler_shuffle = self.shuffle_files
-        sampler = None
-        self.sampler = _select_sampler(num_samples, sampler, sampler_shuffle, num_shards, shard_id,
-                                       non_mappable=True)

     def get_args(self):
         args = super().get_args()