|
|
|
@@ -25,6 +25,7 @@ import os |
|
|
|
import random |
|
|
|
import uuid |
|
|
|
import multiprocessing |
|
|
|
import queue |
|
|
|
from enum import Enum |
|
|
|
from importlib import import_module |
|
|
|
|
|
|
|
@@ -2124,6 +2125,142 @@ def _cpp_sampler_fn(sampler, dataset): |
|
|
|
yield tuple([np.array(x) for x in val]) |
|
|
|
|
|
|
|
|
|
|
|
def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
    """
    Multiprocessing generator function wrapper for mappable dataset with cpp sampler.

    Pulls the full index list from the cpp sampler up front, then delegates the
    multi-worker fetch loop to _sampler_fn_mp.
    """
    return _sampler_fn_mp(sampler.get_indices(), dataset, num_worker)
|
|
|
|
|
|
|
|
|
|
|
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
    """
    Multiprocessing generator function wrapper for mappable dataset with python sampler.

    Materializes up to num_samples indices from the python sampler, then delegates
    the multi-worker fetch loop to _sampler_fn_mp.
    """
    sample_ids = _fetch_py_sampler_indices(sampler, num_samples)
    return _sampler_fn_mp(sample_ids, dataset, num_worker)
|
|
|
|
|
|
|
|
|
|
|
def _fetch_py_sampler_indices(sampler, num_samples): |
|
|
|
""" |
|
|
|
Indices fetcher for python sampler |
|
|
|
""" |
|
|
|
if num_samples is not None: |
|
|
|
sampler_iter = iter(sampler) |
|
|
|
ret = [] |
|
|
|
for _ in range(num_samples): |
|
|
|
try: |
|
|
|
val = next(sampler_iter) |
|
|
|
ret.append(val) |
|
|
|
except StopIteration: |
|
|
|
break |
|
|
|
return ret |
|
|
|
return [i for i in sampler] |
|
|
|
|
|
|
|
|
|
|
|
def _fill_worker_indices(workers, indices, idx): |
|
|
|
""" |
|
|
|
Worker index queue filler, fill worker index queue in round robin order |
|
|
|
""" |
|
|
|
num_worker = len(workers) |
|
|
|
while idx < len(indices): |
|
|
|
try: |
|
|
|
workers[idx % num_worker].put(indices[idx]) |
|
|
|
idx += 1 |
|
|
|
except queue.Full: |
|
|
|
break |
|
|
|
return idx |
|
|
|
|
|
|
|
|
|
|
|
def _sampler_fn_mp(indices, dataset, num_worker):
    """
    Multiprocessing generator function wrapper master process.

    Spawns `num_worker` _GeneratorWorker processes, feeds them `indices` in
    round-robin order through bounded per-worker queues, and yields results
    back in the order of `indices`.

    Args:
        indices (list): Sample indices to fetch; output order follows this list.
        dataset: Mappable object; each worker evaluates dataset[idx].
        num_worker (int): Number of worker processes to create.

    Yields:
        tuple: One tuple of numpy arrays per index.

    Raises:
        Exception: On worker result timeout or KeyboardInterrupt.
    """
    workers = []
    # Event for end of epoch
    eoe = multiprocessing.Event()

    # Create workers
    for _ in range(num_worker):
        worker = _GeneratorWorker(dataset, eoe)
        # Daemonize so stray workers die with the master process.
        worker.daemon = True
        workers.append(worker)

    # Fill initial index queues
    idx_cursor = 0
    idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)

    # Start all workers
    for w in workers:
        w.start()

    # Fetch results
    for i in range(len(indices)):
        # Fetch result and put index
        try:
            # Indices were distributed round-robin, so result i is expected
            # from worker i % num_worker, preserving output order.
            result = workers[i % num_worker].get()
        except queue.Empty:
            raise Exception("Generator worker process timeout")
        except KeyboardInterrupt:
            for w in workers:
                w.terminate()
                w.join()
            raise Exception("Generator worker receives KeyboardInterrupt")
        if idx_cursor < len(indices):
            # Top up the bounded index queues now that a result freed capacity.
            idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
        # Set eoe event once all indices are sent
        if idx_cursor == len(indices) and not eoe.is_set():
            eoe.set()
        yield tuple([np.array(x) for x in result])
|
|
|
|
|
|
|
|
|
|
|
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): |
|
|
|
""" |
|
|
|
Multiprocessing generator worker process loop |
|
|
|
""" |
|
|
|
while True: |
|
|
|
# Fetch index, block |
|
|
|
try: |
|
|
|
idx = idx_queue.get() |
|
|
|
except KeyboardInterrupt: |
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt") |
|
|
|
if idx is None: |
|
|
|
# When the queue is out of scope from master process, a None item can be fetched from the queue. |
|
|
|
# Upon receiving None, worker process should check if EOE is set. |
|
|
|
assert eoe.is_set(), "" |
|
|
|
return |
|
|
|
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process |
|
|
|
result = dataset[idx] |
|
|
|
# Send data, block |
|
|
|
try: |
|
|
|
result_queue.put(result) |
|
|
|
except KeyboardInterrupt: |
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt") |
|
|
|
del result, idx |
|
|
|
|
|
|
|
|
|
|
|
class _GeneratorWorker(multiprocessing.Process):
    """
    Worker process for multiprocess Generator.

    Owns a bounded index queue (input) and a bounded result queue (output),
    and runs _generator_worker_loop as its process target.
    """

    def __init__(self, dataset, eoe):
        # Bounded queues (capacity 16) provide backpressure between the
        # master and this worker.
        self.idx_queue = multiprocessing.Queue(16)
        self.res_queue = multiprocessing.Queue(16)
        super().__init__(
            target=_generator_worker_loop,
            args=(dataset, self.idx_queue, self.res_queue, eoe),
        )

    def put(self, item):
        """
        Enqueue an index without blocking. Raises queue.Full when the index
        queue is at capacity.
        """
        self.idx_queue.put_nowait(item)

    def get(self):
        """
        Dequeue one result, blocking up to 5 seconds. Raises queue.Empty on
        timeout.
        """
        return self.res_queue.get(timeout=5)
|
|
|
|
|
|
|
|
|
|
|
class GeneratorDataset(SourceDataset): |
|
|
|
""" |
|
|
|
A source dataset that generate data from python by invoking python data source each epoch. |
|
|
|
@@ -2171,6 +2308,7 @@ class GeneratorDataset(SourceDataset): |
|
|
|
If the schema is not provided, the meta data from column_names and column_types is considered the schema. |
|
|
|
num_samples (int, optional): The number of samples to be included in the dataset |
|
|
|
(default=None, all images). |
|
|
|
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1). |
|
|
|
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. |
|
|
|
(default=None, expected order behavior shown in the table). |
|
|
|
sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is |
|
|
|
@@ -2229,9 +2367,15 @@ class GeneratorDataset(SourceDataset): |
|
|
|
sampler_instance.set_num_rows(len(source)) |
|
|
|
sampler_instance.set_num_samples(num_samples) |
|
|
|
sampler_instance.initialize() |
|
|
|
self.source = (lambda: _cpp_sampler_fn(sampler_instance, source)) |
|
|
|
if num_parallel_workers > 1: |
|
|
|
self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers)) |
|
|
|
else: |
|
|
|
self.source = (lambda: _cpp_sampler_fn(sampler_instance, source)) |
|
|
|
else: |
|
|
|
self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) |
|
|
|
if num_parallel_workers > 1: |
|
|
|
self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers)) |
|
|
|
else: |
|
|
|
self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) |
|
|
|
else: |
|
|
|
try: |
|
|
|
iter(source) |
|
|
|
|