@@ -51,6 +51,7 @@ import mindspore.dataset.transforms.py_transforms as py_transforms
 from . import samplers
 from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterator_cleanup, _set_iterator_cleanup, \
     ITERATORS_LIST, _unset_iterator_cleanup
+from .queue import _SharedQueue
 from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
     check_rename, check_numpyslicesdataset, check_device_send, \
     check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \
@@ -58,7 +59,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
     check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
     check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
     check_paddeddataset, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send
-from ..core.config import get_callback_timeout, _init_device_info
+from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
+    get_prefetch_size
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
 from ..core.validator_helpers import replace_none
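
The new imports are consulted at iterator bootstrap: get_enable_shared_mem() decides
whether _SharedQueue instances are created at all, while get_num_parallel_workers() and
get_prefetch_size() supply defaults when an op does not set its own. A minimal tuning
sketch under stated assumptions (the getters are imported above; the matching
set_enable_shared_mem setter is assumed here, not shown in this patch):

    import mindspore.dataset as ds

    ds.config.set_num_parallel_workers(4)  # default worker count for ops that don't set one
    ds.config.set_prefetch_size(16)        # baseline per-op prefetch depth
    ds.config.set_enable_shared_mem(True)  # assumed setter gating the _SharedQueue path
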
@@ -1917,12 +1919,14 @@ class BatchDataset(Dataset):
             same).
         pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
             will pad the column named "col1" to a tensor of size [224,224] and fill any missing values with 0.
+        max_rowsize (int, optional): Maximum size of a row in MB that is used for shared memory allocation to copy
+            data between processes. This is only used if python_multiprocessing is set to True (default 16 MB).

     """

     def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
                  input_columns=None, output_columns=None, column_order=None, pad_info=None,
-                 python_multiprocessing=False):
+                 python_multiprocessing=False, max_rowsize=16):
         super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)

         if BatchDataset._is_ancestor_of_repeat(input_dataset):
@@ -1948,6 +1952,7 @@ class BatchDataset(Dataset):
         self.python_multiprocessing = python_multiprocessing
         self.process_pool = None
         self.hook = None
+        self.max_rowsize = max_rowsize

     def parse(self, children=None):
         return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad, self.input_columns,
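
A hedged usage sketch of the new parameter: assuming the public Dataset.batch() wrapper
forwards max_rowsize to BatchDataset (the wrapper itself is not shown in this excerpt),
a pipeline with a heavy per_batch_map can raise the per-row shared-memory budget:

    import numpy as np
    import mindspore.dataset as ds

    def double_batch(rows, batch_info):
        # rows is the list of "img" arrays in the batch; runs in a pool worker
        return ([r * 2 for r in rows],)

    data = ds.NumpySlicesDataset(np.ones((128, 224, 224, 3), dtype=np.float32),
                                 column_names=["img"])
    data = data.batch(32, per_batch_map=double_batch, input_columns=["img"],
                      python_multiprocessing=True, max_rowsize=32)  # 32 MB per row
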
@@ -1997,10 +2002,25 @@ class BatchDataset(Dataset):
         Per iterator bootstrap callback.
         """
         if self.python_multiprocessing:
+            arg_q_list = []
+            res_q_list = []
+
+            # If the user didn't specify num_parallel_workers, set it to the default
+            if self.num_parallel_workers is not None:
+                num_parallel = self.num_parallel_workers
+            else:
+                num_parallel = get_num_parallel_workers()
+
+            if get_enable_shared_mem():
+                for _ in range(num_parallel):
+                    arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
+                    res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
+
             # Construct the pool with the callable list
             # The callable list and _pyfunc_worker_init are used to pass lambda functions to the subprocesses
-            self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
-                                                     initializer=_pyfunc_worker_init, initargs=([self.per_batch_map],))
+            self.process_pool = multiprocessing.Pool(processes=num_parallel,
+                                                     initializer=_pyfunc_worker_init,
+                                                     initargs=([self.per_batch_map], arg_q_list, res_q_list))

             idx = 0
             global _OP_NAME, _OP_PROCESS, _LOCK
@@ -2013,7 +2033,7 @@ class BatchDataset(Dataset):
             _OP_PROCESS.update(process_id)

             # Wrap per_batch_map into _PythonCallable
-            self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool)
+            self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool, arg_q_list, res_q_list)
             self.hook = _ExceptHookHandler()
             atexit.register(_mp_pool_exit_preprocess)
             # If the Python version is greater than 3.8, we need to close the ThreadPool in atexit for unclean pool teardown.
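
Each worker gets a dedicated single-slot argument queue and result queue. Because
per_batch_map moves a whole batch per call, the slot budget scales with
max_rowsize * batch_size. A worked instance of the arithmetic (the worker count is
illustrative):

    max_rowsize = 16                            # BatchDataset default, in MB per row
    batch_size = 32
    slot_budget_mb = max_rowsize * batch_size   # 512 MB of shared memory per queue slot
    num_parallel = 8                            # e.g. get_num_parallel_workers()
    total_queues = 2 * num_parallel             # one arg queue + one result queue per worker
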
@@ -2211,6 +2231,8 @@ class ShuffleDataset(Dataset):
 # Pyfunc collection for multiprocess pyfunc
 # This global variable will only be used within subprocesses
 _GLOBAL_PYFUNC_LIST = []
+_ARGS_QUEUE = []
+_RET_QUEUE = []
 _OP_NAME = dict()
 _OP_PROCESS = dict()
 _LOCK = threading.Lock()
@@ -2219,22 +2241,37 @@ _LOCK = threading.Lock()
 # Pyfunc worker init function
 # The Python multiprocessing library forbids sending lambda functions through a pipe.
 # This init function allows us to add all Python functions to a global collection and then fork afterwards.
-def _pyfunc_worker_init(pyfunc_list):
+def _pyfunc_worker_init(pyfunc_list, args_queue, ret_queue):
     global _GLOBAL_PYFUNC_LIST
+    global _ARGS_QUEUE
+    global _RET_QUEUE
     _GLOBAL_PYFUNC_LIST = pyfunc_list
+    _ARGS_QUEUE = args_queue
+    _RET_QUEUE = ret_queue


 # Pyfunc worker execution function
 # All exceptions will be raised to the main process
-def _pyfunc_worker_exec(index, *args):
+def _pyfunc_worker_exec(index, qid, *args):
     """
     Internal function for calling a given pyfunc in a Python process.
     """
     # Some threads in multiprocess.pool can't process the SIGINT signal
     # and will hang, so Ctrl+C is passed to the parent process.
     signal.signal(signal.SIGINT, signal.SIG_IGN)
-    return _GLOBAL_PYFUNC_LIST[index](*args)
+    if qid != -1:
+        # Pass arguments through the queue instead of directly to the remote process
+        args = _ARGS_QUEUE[qid].get()
+        r = _GLOBAL_PYFUNC_LIST[index](*args)
+        if isinstance(r, tuple):
+            _RET_QUEUE[qid].put(r)
+        else:
+            _RET_QUEUE[qid].put((r,))
+        return [qid]
+    # Not using shared memory for passing arguments; call the function directly
+    return _GLOBAL_PYFUNC_LIST[index](*args)


 # PythonCallable wrapper for multiprocess pyfunc
 class _PythonCallable:
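
The initializer trick above is what lets lambdas cross the process boundary: with the
fork start method, initargs are inherited rather than pickled, so each worker installs
the callables into its own globals and callers send only an integer index through the
pipe. A standalone, runnable sketch of the same pattern (independent of MindSpore; fork
start method assumed, as on Linux):

    import multiprocessing

    _FUNCS = []

    def _worker_init(funcs):
        # Runs once per worker after fork; installs the callables into worker globals.
        global _FUNCS
        _FUNCS = funcs

    def _worker_exec(index, *args):
        # Only the index and args cross the pipe; the callable itself never does.
        return _FUNCS[index](*args)

    if __name__ == "__main__":
        ctx = multiprocessing.get_context("fork")   # the trick relies on fork inheritance
        funcs = [lambda x: x + 1, lambda x: x * 2]  # lambdas are not picklable
        with ctx.Pool(2, initializer=_worker_init, initargs=(funcs,)) as pool:
            print(pool.apply_async(_worker_exec, [1, 21]).get())  # prints 42
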
@@ -2242,7 +2279,7 @@ class _PythonCallable:
     Internal Python function wrapper for multiprocessing pyfunc.
     """

-    def __init__(self, py_callable, idx, pool=None):
+    def __init__(self, py_callable, idx, pool=None, arg_q=None, res_q=None):
         # Original Python callable from user.
         self.py_callable = py_callable
         # Process pool created for current iterator.
@@ -2250,19 +2287,47 @@ class _PythonCallable:
         # Python callable index for subprocess _GLOBAL_PYFUNC_LIST
         self.idx = idx

+        if pool is not None:
+            self.queuemap = {}
+            self.arg_q = arg_q
+            self.res_q = res_q
+            self.next_queue = 0
+
     def __call__(self, *args):
         # Note here: the RUN state of Python 3.7 and 3.8 differs:
         # Python 3.7: RUN = 0
         # Python 3.8: RUN = "RUN"
         # So we use self.pool._state == RUN instead; we can't use _state == 0 any more.
         if self.pool is not None and self.pool._state == RUN and check_iterator_cleanup() is False:  # pylint: disable=W0212
-            # This call will send the tensors along with the Python callable index to the process pool.
-            # Block, yield GIL. The current thread will reacquire the GIL once the result is returned.
-            result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
+            # arg_q will have 0 size if we are not using shared memory;
+            # otherwise use the multiprocessing shared queue instead of multiprocess arg passing
+            if self.arg_q != []:
+                tid = threading.get_ident()
+                # Register each thread to a different queue to send data to the pool
+                if tid not in self.queuemap:
+                    qid = self.next_queue
+                    self.next_queue = self.next_queue + 1
+                    self.queuemap[tid] = qid
+                else:
+                    qid = self.queuemap[tid]
+                self.arg_q[qid].put(args)
+
+                # This call will send the tensors along with the Python callable index to the process pool.
+                # Block, yield GIL. The current thread will reacquire the GIL once the result is returned.
+                result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, qid, []])
+            else:
+                result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, -1, *args])
+
+            # TODO: this check might be wrong
             while check_iterator_cleanup() is False:
                 try:
+                    if self.arg_q != []:
+                        r = result.get(30)
+                        if r[0] != qid:
+                            raise Exception("In PyCallable, got results from wrong thread")
+                        r = self.res_q[qid].get()
+                        return r
                     return result.get(30)
                 except multiprocessing.TimeoutError:
                     continue
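
The call path above rendezvouses on two channels: result.get(30) returns only the queue
id as a completion token (retried on timeout so that iterator cleanup can break the
wait), while the actual tensors travel through the shared-memory res_q[qid]. Pinning
each caller thread to its own queue pair via threading.get_ident() keeps concurrent
callers from reading each other's results. A runnable toy of the token-plus-side-channel
pattern with plain multiprocessing queues (fork start method assumed):

    import multiprocessing

    def _init(queues):
        global _RES_QUEUES
        _RES_QUEUES = queues

    def _exec(qid, x):
        _RES_QUEUES[qid].put(x * x)  # payload goes through the side channel
        return [qid]                 # only a completion token crosses the result pipe

    if __name__ == "__main__":
        ctx = multiprocessing.get_context("fork")
        res_queues = [ctx.Queue(1), ctx.Queue(1)]
        with ctx.Pool(2, initializer=_init, initargs=(res_queues,)) as pool:
            token = pool.apply_async(_exec, [0, 7])
            assert token.get(30) == [0]  # the token names the queue to read
            print(res_queues[0].get())   # prints 49
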
@@ -2318,13 +2383,15 @@ class MapDataset(Dataset):
         cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
             (default=None, which means no cache is used).
         callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (default=None).
+        max_rowsize (int, optional): Maximum size of a row in MB that is used for shared memory allocation to copy
+            data between processes. This is only used if python_multiprocessing is set to True (default 16 MB).

     Raises:
         ValueError: If len(input_columns) != len(output_columns) and column_order is not specified.
     """

     def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None,
-                 num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
+                 num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16):
         super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache)
         self.operations = to_list(operations)
         self.operations = py_transforms.Compose.reduce(self.operations)
@@ -2347,6 +2414,7 @@ class MapDataset(Dataset):
         self.hook = None

         self.callbacks = to_list(callbacks)
+        self.max_rowsize = max_rowsize

     def parse(self, children=None):
         operations = []
@@ -2374,6 +2442,19 @@ class MapDataset(Dataset):
         if self.python_multiprocessing:
             iter_specific_operations = []
             callable_list = []
+            arg_q_list = []
+            res_q_list = []
+
+            # If the user didn't specify num_parallel_workers, set it to the default
+            if self.num_parallel_workers is not None:
+                num_parallel = self.num_parallel_workers
+            else:
+                num_parallel = get_num_parallel_workers()
+
+            if get_enable_shared_mem():
+                for _ in range(num_parallel):
+                    arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
+                    res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))

             # Pass #1, look for Python callables and build list
             for op in self.operations:
@@ -2384,8 +2465,9 @@ class MapDataset(Dataset):
             if callable_list:
                 # Construct the pool with the callable list
                 # The callable list and _pyfunc_worker_init are used to pass lambda functions to the subprocesses
-                self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
-                                                         initializer=_pyfunc_worker_init, initargs=(callable_list,))
+                self.process_pool = multiprocessing.Pool(processes=num_parallel,
+                                                         initializer=_pyfunc_worker_init,
+                                                         initargs=(callable_list, arg_q_list, res_q_list))

             # Pass #2
             idx = 0
@@ -2401,7 +2483,8 @@ class MapDataset(Dataset):
                 # Our c_transforms ops are now callable and should not be run in Python multithreading
                 if callable(op) and str(op).find("c_transform") < 0:
                     # Wrap Python callable into _PythonCallable
-                    iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
+                    iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool,
+                                                                    arg_q_list, res_q_list))
                     idx += 1
                 else:
                     # CPP ops remain the same
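
Usage sketch for the map path, assuming the public Dataset.map() wrapper forwards
max_rowsize to MapDataset (the wrapper is not shown in this excerpt). Unlike batch, the
map queues are sized with max_rowsize alone, since workers move one row at a time:

    import numpy as np
    import mindspore.dataset as ds

    def heavy_pyfunc(img):
        # arbitrary Python work; runs in a pool worker when python_multiprocessing=True
        return img.astype(np.float32) / 255.0

    data = ds.NumpySlicesDataset(np.ones((64, 512, 512, 3), dtype=np.uint8),
                                 column_names=["img"])
    data = data.map(operations=heavy_pyfunc, input_columns=["img"],
                    num_parallel_workers=4, python_multiprocessing=True,
                    max_rowsize=8)  # 8 MB rows cover the 3 MB 512x512x3 float32 outputs
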
@@ -3186,7 +3269,7 @@ class SamplerFn:
     Multiprocessing or multithreading generator function wrapper for the master process.
     """

-    def __init__(self, dataset, num_worker, multi_process):
+    def __init__(self, dataset, num_worker, multi_process, max_rowsize):
         self.workers = []
         self.num_worker = num_worker
         self.multi_process = multi_process
@@ -3199,9 +3282,16 @@ class SamplerFn:
         else:
             self.eof = threading.Event()
         # Create workers
+
+        # Get the default queue size and shrink the per-worker queue size if there are many workers
+        queue_size = get_prefetch_size()
+        queue_size = min(queue_size, queue_size * 4 // num_worker)
+        queue_size = max(2, queue_size)
+
         for _ in range(num_worker):
             if multi_process is True:
-                worker = _GeneratorWorkerMp(dataset, self.eof)
+                worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size)
                 worker.daemon = True
                 # When multiprocessing forks a subprocess, the lock of the main process is copied to the subprocess,
                 # which may cause deadlock. Therefore, the subprocess startup is performed in the initialization phase.
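
The sizing above keeps the total number of buffered rows roughly bounded as workers
scale: per-worker depth shrinks once there are more than four workers, with a floor of
two slots. A worked instance, assuming the default prefetch size is 16:

    queue_size = 16                                             # get_prefetch_size()
    num_worker = 8
    queue_size = min(queue_size, queue_size * 4 // num_worker)  # min(16, 8) -> 8
    queue_size = max(2, queue_size)                             # stays 8
    # With 64 workers: min(16, 16 * 4 // 64) = 1, then max(2, 1) = 2 slots per worker.
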
@@ -3363,9 +3453,12 @@ class _GeneratorWorkerMp(multiprocessing.Process):
     Worker process for multiprocess Generator.
     """

-    def __init__(self, dataset, eof):
-        self.idx_queue = multiprocessing.Queue(16)
-        self.res_queue = multiprocessing.Queue(16)
+    def __init__(self, dataset, eof, max_rowsize, queue_size):
+        self.idx_queue = multiprocessing.Queue(queue_size)
+        if get_enable_shared_mem():
+            self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize)
+        else:
+            self.res_queue = multiprocessing.Queue(queue_size)
         super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True))

     def put(self, item):
@@ -3453,6 +3546,8 @@ class GeneratorDataset(MappableDataset):
             when num_shards is also specified. Random accessible input is required.
         python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
             option could be beneficial if the Python operation is computationally heavy (default=True).
+        max_rowsize (int, optional): Maximum size of a row in MB that is used for shared memory allocation to copy
+            data between processes. This is only used if python_multiprocessing is set to True (default 6 MB).

     Examples:
         >>> import numpy as np
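
With the signature change below, a random-accessible source with large rows can raise
the budget directly; a small sketch (shapes illustrative: one 1024x1024x3 float32 row is
12 MB, above the 6 MB default):

    import numpy as np
    import mindspore.dataset as ds

    class MySource:
        def __init__(self):
            self.data = np.ones((16, 1024, 1024, 3), dtype=np.float32)
        def __getitem__(self, index):
            return (self.data[index],)
        def __len__(self):
            return len(self.data)

    data = ds.GeneratorDataset(MySource(), column_names=["img"],
                               num_parallel_workers=4, python_multiprocessing=True,
                               max_rowsize=16)
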
@@ -3516,7 +3611,7 @@ class GeneratorDataset(MappableDataset):
     @check_generatordataset
     def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
                  num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
-                 python_multiprocessing=True):
+                 python_multiprocessing=True, max_rowsize=6):
         super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                          shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
         self.source = source
@@ -3542,6 +3637,8 @@ class GeneratorDataset(MappableDataset):
         if hasattr(self.source, "__len__"):
             self.source_len = len(self.source)

+        self.max_rowsize = max_rowsize
+
     def __deepcopy__(self, memodict):
         if id(self) in memodict:
             return memodict[id(self)]
@@ -3550,7 +3647,8 @@ class GeneratorDataset(MappableDataset):
         sample_fn = None
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
             if new_op.num_parallel_workers > 1:
-                sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
+                sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
+                                      self.max_rowsize)
                 new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
             else:
                 new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))