|
|
|
@@ -205,12 +205,12 @@ class Dataset: |
|
|
|
@check_sync_wait |
|
|
|
def sync_wait(self, condition_name, num_batch=1, callback=None): |
|
|
|
''' |
|
|
|
Add a blocking condition to the input Dataset |
|
|
|
Add a blocking condition to the input Dataset. |
|
|
|
|
|
|
|
Args: |
|
|
|
num_batch (int): the number of batches without blocking at the start of each epoch |
|
|
|
condition_name (str): The condition name that is used to toggle sending next row |
|
|
|
callback (function): The callback function that will be invoked when sync_update is called |
|
|
|
num_batch (int): the number of batches without blocking at the start of each epoch. |
|
|
|
condition_name (str): The condition name that is used to toggle sending next row. |
|
|
|
callback (function): The callback function that will be invoked when sync_update is called. |
|
|
|
|
|
|
|
Raises: |
|
|
|
RuntimeError: If condition name already exists. |
|
|
|
@@ -920,13 +920,13 @@ class Dataset: |
|
|
|
|
|
|
|
def sync_update(self, condition_name, num_batch=None, data=None): |
|
|
|
""" |
|
|
|
Release a blocking condition and trigger callback with given data |
|
|
|
Release a blocking condition and trigger callback with given data. |
|
|
|
|
|
|
|
Args: |
|
|
|
condition_name (str): The condition name that is used to toggle sending next row |
|
|
|
num_batch (int or None): The number of batches(rows) that are released |
|
|
|
When num_batch is None, it will default to the number specified by the sync_wait operator |
|
|
|
data (dict or None): The data passed to the callback |
|
|
|
condition_name (str): The condition name that is used to toggle sending next row. |
|
|
|
num_batch (int or None): The number of batches(rows) that are released. |
|
|
|
When num_batch is None, it will default to the number specified by the sync_wait operator. |
|
|
|
data (dict or None): The data passed to the callback. |
|
|
|
""" |
|
|
|
notifiers_dict = self.get_sync_notifiers() |
|
|
|
if condition_name not in notifiers_dict: |
|
|
|
@@ -948,7 +948,7 @@ class Dataset: |
|
|
|
|
|
|
|
def get_repeat_count(self): |
|
|
|
""" |
|
|
|
Get the replication times in RepeatDataset else 1 |
|
|
|
Get the replication times in RepeatDataset else 1. |
|
|
|
|
|
|
|
Return: |
|
|
|
Number, the count of repeat. |
|
|
|
@@ -969,7 +969,7 @@ class Dataset: |
|
|
|
raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self))) |
|
|
|
|
|
|
|
def reset(self): |
|
|
|
"""Reset the dataset for next epoch""" |
|
|
|
"""Reset the dataset for next epoch.""" |
|
|
|
|
|
|
|
|
|
|
|
class SourceDataset(Dataset): |
|
|
|
@@ -1085,9 +1085,9 @@ class BatchDataset(DatasetOp): |
|
|
|
Utility function to find the case where repeat is used before batch. |
|
|
|
|
|
|
|
Args: |
|
|
|
dataset (Dataset): dataset to be checked |
|
|
|
dataset (Dataset): dataset to be checked. |
|
|
|
Return: |
|
|
|
True or False |
|
|
|
True or False. |
|
|
|
""" |
|
|
|
if isinstance(dataset, RepeatDataset): |
|
|
|
return True |
|
|
|
@@ -1102,8 +1102,8 @@ class BatchDataset(DatasetOp): |
|
|
|
Utility function to notify batch size to sync_wait. |
|
|
|
|
|
|
|
Args: |
|
|
|
dataset (Dataset): dataset to be checked |
|
|
|
batchsize (int): batch size to notify |
|
|
|
dataset (Dataset): dataset to be checked. |
|
|
|
batchsize (int): batch size to notify. |
|
|
|
""" |
|
|
|
if isinstance(dataset, SyncWaitDataset): |
|
|
|
dataset.update_sync_batch_size(batch_size) |
|
|
|
@@ -1136,11 +1136,11 @@ class BatchInfo(CBatchInfo): |
|
|
|
|
|
|
|
class BlockReleasePair: |
|
|
|
""" |
|
|
|
The blocking condition class used by SyncWaitDataset |
|
|
|
The blocking condition class used by SyncWaitDataset. |
|
|
|
|
|
|
|
Args: |
|
|
|
init_release_rows (int): Number of lines to allow through the pipeline |
|
|
|
callback (function): The callback function that will be called when release is called |
|
|
|
init_release_rows (int): Number of lines to allow through the pipeline. |
|
|
|
callback (function): The callback function that will be called when release is called. |
|
|
|
""" |
|
|
|
def __init__(self, init_release_rows, callback=None): |
|
|
|
self.row_count = -init_release_rows |
|
|
|
@@ -1183,13 +1183,13 @@ class BlockReleasePair: |
|
|
|
|
|
|
|
class SyncWaitDataset(DatasetOp): |
|
|
|
""" |
|
|
|
The result of adding a blocking condition to the input Dataset |
|
|
|
The result of adding a blocking condition to the input Dataset. |
|
|
|
|
|
|
|
Args: |
|
|
|
input_dataset (Dataset): Input dataset to apply flow control |
|
|
|
num_batch (int): the number of batches without blocking at the start of each epoch |
|
|
|
condition_name (str): The condition name that is used to toggle sending next row |
|
|
|
callback (function): The callback function that will be invoked when sync_update is called |
|
|
|
input_dataset (Dataset): Input dataset to apply flow control. |
|
|
|
num_batch (int): the number of batches without blocking at the start of each epoch. |
|
|
|
condition_name (str): The condition name that is used to toggle sending next row. |
|
|
|
callback (function): The callback function that will be invoked when sync_update is called. |
|
|
|
|
|
|
|
Raises: |
|
|
|
RuntimeError: If condition name already exists. |
|
|
|
@@ -1226,9 +1226,9 @@ class SyncWaitDataset(DatasetOp): |
|
|
|
Utility function to find the case where sync_wait is used before batch. |
|
|
|
|
|
|
|
Args: |
|
|
|
dataset (Dataset): dataset to be checked |
|
|
|
dataset (Dataset): dataset to be checked. |
|
|
|
Return: |
|
|
|
True or False |
|
|
|
True or False. |
|
|
|
""" |
|
|
|
if isinstance(dataset, BatchDataset): |
|
|
|
return True |
|
|
|
@@ -1289,7 +1289,7 @@ def _pyfunc_worker_exec(index, *args): |
|
|
|
# PythonCallable wrapper for multiprocess pyfunc |
|
|
|
class _PythonCallable: |
|
|
|
""" |
|
|
|
Internal python function wrapper for multiprocessing pyfunc |
|
|
|
Internal python function wrapper for multiprocessing pyfunc. |
|
|
|
""" |
|
|
|
def __init__(self, py_callable, idx, pool=None): |
|
|
|
# Original python callable from user. |
|
|
|
@@ -1467,7 +1467,7 @@ class FilterDataset(DatasetOp): |
|
|
|
def get_dataset_size(self): |
|
|
|
""" |
|
|
|
Get the number of batches in an epoch. |
|
|
|
the size cannot be determined before we run the pipeline |
|
|
|
the size cannot be determined before we run the pipeline. |
|
|
|
Return: |
|
|
|
0 |
|
|
|
""" |
|
|
|
@@ -1759,7 +1759,7 @@ class StorageDataset(SourceDataset): |
|
|
|
columns_list (list[str], optional): List of columns to be read (default=None, read all columns). |
|
|
|
num_parallel_workers (int, optional): Number of parallel working threads (default=None). |
|
|
|
deterministic_output (bool, optional): Whether the result of this dataset can be reproduced |
|
|
|
or not (default=True). If True, performance might be affected. |
|
|
|
or not (default=True). If True, performance might be affected. |
|
|
|
prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None). |
|
|
|
|
|
|
|
Raises: |
|
|
|
@@ -1889,11 +1889,11 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): |
|
|
|
Create sampler based on user input. |
|
|
|
|
|
|
|
Args: |
|
|
|
num_samples (int): Number of samples |
|
|
|
input_sampler (Iterable / Sampler): Sampler from user |
|
|
|
shuffle (bool): Shuffle |
|
|
|
num_shards (int): Number of shard for sharding |
|
|
|
shard_id (int): Shard ID |
|
|
|
num_samples (int): Number of samples. |
|
|
|
input_sampler (Iterable / Sampler): Sampler from user. |
|
|
|
shuffle (bool): Shuffle. |
|
|
|
num_shards (int): Number of shard for sharding. |
|
|
|
shard_id (int): Shard ID. |
|
|
|
""" |
|
|
|
if shuffle is None: |
|
|
|
if input_sampler is not None: |
|
|
|
@@ -2265,7 +2265,7 @@ class MindDataset(SourceDataset): |
|
|
|
|
|
|
|
def _iter_fn(dataset, num_samples): |
|
|
|
""" |
|
|
|
Generator function wrapper for iterable dataset |
|
|
|
Generator function wrapper for iterable dataset. |
|
|
|
""" |
|
|
|
if num_samples is not None: |
|
|
|
ds_iter = iter(dataset) |
|
|
|
@@ -2284,7 +2284,7 @@ def _iter_fn(dataset, num_samples): |
|
|
|
|
|
|
|
def _generator_fn(generator, num_samples): |
|
|
|
""" |
|
|
|
Generator function wrapper for generator function dataset |
|
|
|
Generator function wrapper for generator function dataset. |
|
|
|
""" |
|
|
|
if num_samples is not None: |
|
|
|
gen_iter = generator() |
|
|
|
@@ -2302,7 +2302,7 @@ def _generator_fn(generator, num_samples): |
|
|
|
|
|
|
|
def _py_sampler_fn(sampler, num_samples, dataset): |
|
|
|
""" |
|
|
|
Generator function wrapper for mappable dataset with python sampler |
|
|
|
Generator function wrapper for mappable dataset with python sampler. |
|
|
|
""" |
|
|
|
if num_samples is not None: |
|
|
|
sampler_iter = iter(sampler) |
|
|
|
@@ -2323,7 +2323,7 @@ def _py_sampler_fn(sampler, num_samples, dataset): |
|
|
|
|
|
|
|
def _cpp_sampler_fn(sampler, dataset): |
|
|
|
""" |
|
|
|
Generator function wrapper for mappable dataset with cpp sampler |
|
|
|
Generator function wrapper for mappable dataset with cpp sampler. |
|
|
|
""" |
|
|
|
indices = sampler.get_indices() |
|
|
|
for i in indices: |
|
|
|
@@ -2334,7 +2334,7 @@ def _cpp_sampler_fn(sampler, dataset): |
|
|
|
|
|
|
|
def _cpp_sampler_fn_mp(sampler, dataset, num_worker): |
|
|
|
""" |
|
|
|
Multiprocessing generator function wrapper for mappable dataset with cpp sampler |
|
|
|
Multiprocessing generator function wrapper for mappable dataset with cpp sampler. |
|
|
|
""" |
|
|
|
indices = sampler.get_indices() |
|
|
|
return _sampler_fn_mp(indices, dataset, num_worker) |
|
|
|
@@ -2342,7 +2342,7 @@ def _cpp_sampler_fn_mp(sampler, dataset, num_worker): |
|
|
|
|
|
|
|
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker): |
|
|
|
""" |
|
|
|
Multiprocessing generator function wrapper for mappable dataset with python sampler |
|
|
|
Multiprocessing generator function wrapper for mappable dataset with python sampler. |
|
|
|
""" |
|
|
|
indices = _fetch_py_sampler_indices(sampler, num_samples) |
|
|
|
return _sampler_fn_mp(indices, dataset, num_worker) |
|
|
|
@@ -2350,7 +2350,7 @@ def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker): |
|
|
|
|
|
|
|
def _fetch_py_sampler_indices(sampler, num_samples): |
|
|
|
""" |
|
|
|
Indices fetcher for python sampler |
|
|
|
Indices fetcher for python sampler. |
|
|
|
""" |
|
|
|
if num_samples is not None: |
|
|
|
sampler_iter = iter(sampler) |
|
|
|
@@ -2367,7 +2367,7 @@ def _fetch_py_sampler_indices(sampler, num_samples): |
|
|
|
|
|
|
|
def _fill_worker_indices(workers, indices, idx): |
|
|
|
""" |
|
|
|
Worker index queue filler, fill worker index queue in round robin order |
|
|
|
Worker index queue filler, fill worker index queue in round robin order. |
|
|
|
""" |
|
|
|
num_worker = len(workers) |
|
|
|
while idx < len(indices): |
|
|
|
@@ -2381,7 +2381,7 @@ def _fill_worker_indices(workers, indices, idx): |
|
|
|
|
|
|
|
def _sampler_fn_mp(indices, dataset, num_worker): |
|
|
|
""" |
|
|
|
Multiprocessing generator function wrapper master process |
|
|
|
Multiprocessing generator function wrapper master process. |
|
|
|
""" |
|
|
|
workers = [] |
|
|
|
# Event for end of epoch |
|
|
|
@@ -2423,7 +2423,7 @@ def _sampler_fn_mp(indices, dataset, num_worker): |
|
|
|
|
|
|
|
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): |
|
|
|
""" |
|
|
|
Multiprocessing generator worker process loop |
|
|
|
Multiprocessing generator worker process loop. |
|
|
|
""" |
|
|
|
while True: |
|
|
|
# Fetch index, block |
|
|
|
@@ -2448,7 +2448,7 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): |
|
|
|
|
|
|
|
class _GeneratorWorker(multiprocessing.Process): |
|
|
|
""" |
|
|
|
Worker process for multiprocess Generator |
|
|
|
Worker process for multiprocess Generator. |
|
|
|
""" |
|
|
|
def __init__(self, dataset, eoe): |
|
|
|
self.idx_queue = multiprocessing.Queue(16) |
|
|
|
@@ -2932,7 +2932,7 @@ class ManifestDataset(SourceDataset): |
|
|
|
|
|
|
|
def get_class_indexing(self): |
|
|
|
""" |
|
|
|
Get the class index |
|
|
|
Get the class index. |
|
|
|
|
|
|
|
Return: |
|
|
|
Dict, A str-to-int mapping from label name to index. |
|
|
|
@@ -3500,7 +3500,7 @@ class VOCDataset(SourceDataset): |
|
|
|
|
|
|
|
class CelebADataset(SourceDataset): |
|
|
|
""" |
|
|
|
A source dataset for reading and parsing CelebA dataset. Only supports list_attr_celeba.txt currently |
|
|
|
A source dataset for reading and parsing CelebA dataset. Only supports list_attr_celeba.txt currently. |
|
|
|
|
|
|
|
Note: |
|
|
|
The generated dataset has two columns ['image', 'attr']. |
|
|
|
|