Browse Source

Optimize multi-process in minddata by passing data via shared memory instead of multiprocessing.queue

tags/v1.3.0
RobinGrosman 5 years ago
parent
commit
9a7d1fc034
9 changed files with 329 additions and 25 deletions
  1. +2
    -0
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc
  2. +2
    -1
      mindspore/ccsrc/minddata/dataset/core/config_manager.cc
  3. +9
    -0
      mindspore/ccsrc/minddata/dataset/core/config_manager.h
  4. +27
    -1
      mindspore/dataset/core/config.py
  5. +121
    -23
      mindspore/dataset/engine/datasets.py
  6. +117
    -0
      mindspore/dataset/engine/queue.py
  7. +2
    -0
      tests/ut/cpp/dataset/client_config_test.cc
  8. +27
    -0
      tests/ut/python/dataset/test_datasets_generator.py
  9. +22
    -0
      tests/ut/python/dataset/test_pyfunc.py

+ 2
- 0
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/core/bindings.cc View File

@@ -55,6 +55,8 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
.def("set_op_connector_size", &ConfigManager::set_op_connector_size)
.def("set_seed", &ConfigManager::set_seed)
.def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
.def("set_enable_shared_mem", &ConfigManager::set_enable_shared_mem)
.def("get_enable_shared_mem", &ConfigManager::enable_shared_mem)
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
}));



+ 2
- 1
mindspore/ccsrc/minddata/dataset/core/config_manager.cc View File

@@ -49,7 +49,8 @@ ConfigManager::ConfigManager()
auto_num_workers_(kDftAutoNumWorkers),
num_cpu_threads_(std::thread::hardware_concurrency()),
auto_num_workers_num_shards_(1),
auto_worker_config_(0) {
auto_worker_config_(0),
enable_shared_mem_(true) {
num_cpu_threads_ = num_cpu_threads_ > 0 ? num_cpu_threads_ : std::numeric_limits<uint16_t>::max();
num_parallel_workers_ = num_parallel_workers_ < num_cpu_threads_ ? num_parallel_workers_ : num_cpu_threads_;
auto env_cache_host = std::getenv("MS_CACHE_HOST");


+ 9
- 0
mindspore/ccsrc/minddata/dataset/core/config_manager.h View File

@@ -222,6 +222,14 @@ class ConfigManager {
// @return The experimental config used by AutoNumWorker, each 1 refers to a different setup configuration
void set_auto_worker_config_(uint8_t cfg) { auto_worker_config_ = cfg; }

// setter function
// @param enable - To enable multiprocessing to use shared memory
void set_enable_shared_mem(bool enable) { enable_shared_mem_ = enable; }

// getter function
// @return - Flag to indicate whether shared memory for multi-processing is enabled
bool enable_shared_mem() { return enable_shared_mem_; }

private:
int32_t num_parallel_workers_;
int32_t worker_connector_size_;
@@ -244,6 +252,7 @@ class ConfigManager {
int32_t num_cpu_threads_;
int32_t auto_num_workers_num_shards_;
uint8_t auto_worker_config_;
bool enable_shared_mem_;
// Private helper function that takes a nlohmann json format and populates the settings
// @param j - The json nlohmann json info
Status FromJson(const nlohmann::json &j);


+ 27
- 1
mindspore/dataset/core/config.py View File

@@ -26,7 +26,7 @@ from mindspore import log as logger
__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
'get_num_parallel_workers', 'set_numa_enable', 'get_numa_enable', 'set_monitor_sampling_interval',
'get_monitor_sampling_interval', 'load', 'get_callback_timeout', 'set_auto_num_workers',
'get_auto_num_workers', '_init_device_info']
'get_auto_num_workers', '_init_device_info', 'set_enable_shared_mem', 'get_enable_shared_mem']

INT32_MAX = 2147483647
UINT32_MAX = 4294967295
@@ -374,3 +374,29 @@ def _stop_dataset_profiler():
_config.stop_dataset_profiler(True)
logger.warning("Profiling: waiting for dataset part profiling stop.")
time.sleep(1)

def get_enable_shared_mem():
    """
    Get the default state of the shared-memory-enabled flag.

    Returns:
        bool, whether shared memory queues are used for multiprocessing (default: True).
    """
    return _config.get_enable_shared_mem()

def set_enable_shared_mem(enable):
    """
    Set the default state of the shared memory flag. If shared_mem_enable is True, will use shared memory queues
    to pass data to processes that are created for operators that set multiprocessing=True.

    Args:
        enable (bool): Whether to use shared memory in operators with "multiprocessing=True".

    Raises:
        TypeError: If enable is not a boolean data type.

    Examples:
        >>> ds.config.set_enable_shared_mem(True)
    """
    # Validate here so the documented TypeError is actually raised from Python
    # with a clear message, instead of relying on the C++ binding's conversion.
    if not isinstance(enable, bool):
        raise TypeError("enable must be of type bool.")
    _config.set_enable_shared_mem(enable)

+ 121
- 23
mindspore/dataset/engine/datasets.py View File

@@ -51,6 +51,7 @@ import mindspore.dataset.transforms.py_transforms as py_transforms
from . import samplers
from .iterators import DictIterator, TupleIterator, DummyIterator, check_iterator_cleanup, _set_iterator_cleanup, \
ITERATORS_LIST, _unset_iterator_cleanup
from .queue import _SharedQueue
from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
check_rename, check_numpyslicesdataset, check_device_send, \
check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \
@@ -58,7 +59,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
check_paddeddataset, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send
from ..core.config import get_callback_timeout, _init_device_info
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_prefetch_size
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..core.validator_helpers import replace_none

@@ -1917,12 +1919,14 @@ class BatchDataset(Dataset):
same).
pad_info (dict, optional): Whether to perform padding on selected columns. pad_info={"col1":([224,224],0)}
will pad column with name "col1" to a tensor of size [224,224] and fill the missing with 0.
max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
data between processes. This is only used if python_multiprocessing is set to True (default 16 MB).

"""

def __init__(self, input_dataset, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
input_columns=None, output_columns=None, column_order=None, pad_info=None,
python_multiprocessing=False):
python_multiprocessing=False, max_rowsize=16):
super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers)

if BatchDataset._is_ancestor_of_repeat(input_dataset):
@@ -1948,6 +1952,7 @@ class BatchDataset(Dataset):
self.python_multiprocessing = python_multiprocessing
self.process_pool = None
self.hook = None
self.max_rowsize = max_rowsize

def parse(self, children=None):
return cde.BatchNode(children[0], self.batch_size, self.drop_remainder, self.pad, self.input_columns,
@@ -1997,10 +2002,25 @@ class BatchDataset(Dataset):
Per iterator bootstrap callback.
"""
if self.python_multiprocessing:
arg_q_list = []
res_q_list = []

# If user didn't specify num_parallel_workers, set it to default
if self.num_parallel_workers is not None:
num_parallel = self.num_parallel_workers
else:
num_parallel = get_num_parallel_workers()

if get_enable_shared_mem():
for _ in range(num_parallel):
arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))

# Construct pool with the callable list
# The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses
self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
initializer=_pyfunc_worker_init, initargs=([self.per_batch_map],))
self.process_pool = multiprocessing.Pool(processes=num_parallel,
initializer=_pyfunc_worker_init,
initargs=([self.per_batch_map], arg_q_list, res_q_list))

idx = 0
global _OP_NAME, _OP_PROCESS, _LOCK
@@ -2013,7 +2033,7 @@ class BatchDataset(Dataset):
_OP_PROCESS.update(process_id)

# Wrap per_batch_map into _PythonCallable
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool)
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool, arg_q_list, res_q_list)
self.hook = _ExceptHookHandler()
atexit.register(_mp_pool_exit_preprocess)
# If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown.
@@ -2211,6 +2231,8 @@ class ShuffleDataset(Dataset):
# Pyfunc collection for multiprocess pyfunc
# This global variable will only be used within subprocesses
_GLOBAL_PYFUNC_LIST = []
_ARGS_QUEUE = []
_RET_QUEUE = []
_OP_NAME = dict()
_OP_PROCESS = dict()
_LOCK = threading.Lock()
@@ -2219,22 +2241,37 @@ _LOCK = threading.Lock()
# Pyfunc worker init function
# Python multiprocessing library forbid sending lambda function through pipe.
# This init function allow us to add all Python function to a global collection and then fork afterwards.
def _pyfunc_worker_init(pyfunc_list):
def _pyfunc_worker_init(pyfunc_list, args_queue, ret_queue):
global _GLOBAL_PYFUNC_LIST
global _ARGS_QUEUE
global _RET_QUEUE
_GLOBAL_PYFUNC_LIST = pyfunc_list
_ARGS_QUEUE = args_queue
_RET_QUEUE = ret_queue



# Pyfunc worker execution function
# All exceptions will be raised to main processes
def _pyfunc_worker_exec(index, *args):
def _pyfunc_worker_exec(index, qid, *args):
    """
    Internal function for call certain pyfunc in Python process.

    Args:
        index (int): Index into _GLOBAL_PYFUNC_LIST selecting the callable to run.
        qid (int): Shared-memory queue id; -1 means arguments are passed directly
            through the pool instead of via the shared-memory queues.
        *args: Arguments for the callable (empty when qid != -1, since the real
            arguments are fetched from _ARGS_QUEUE[qid]).

    Returns:
        [qid] when shared memory is used (the actual result is placed on
        _RET_QUEUE[qid]); otherwise the callable's return value.
    """
    # Some threads in multiprocess.pool can't process sigint signal,
    # and will occur hang problem, so ctrl+c will pass to parent process.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    if qid != -1:
        # Pass arguments through the shared-memory queue instead of sending
        # them directly to the remote process.
        args = _ARGS_QUEUE[qid].get()
        r = _GLOBAL_PYFUNC_LIST[index](*args)
        # Results always travel back as a tuple through the result queue.
        if isinstance(r, tuple):
            _RET_QUEUE[qid].put(r)
        else:
            _RET_QUEUE[qid].put((r,))
        # Only the queue id is returned through the pool; the caller reads the
        # real result from _RET_QUEUE[qid].
        return [qid]
    # Not using shared memory for passing arguments; call the function directly.
    return _GLOBAL_PYFUNC_LIST[index](*args)

# PythonCallable wrapper for multiprocess pyfunc
class _PythonCallable:
@@ -2242,7 +2279,7 @@ class _PythonCallable:
Internal Python function wrapper for multiprocessing pyfunc.
"""

def __init__(self, py_callable, idx, pool=None):
def __init__(self, py_callable, idx, pool=None, arg_q=None, res_q=None):
# Original Python callable from user.
self.py_callable = py_callable
# Process pool created for current iterator.
@@ -2250,19 +2287,47 @@ class _PythonCallable:
# Python callable index for subprocess _GLOBAL_PYFUNC_LIST
self.idx = idx

if pool is not None:
self.queuemap = {}
self.arg_q = arg_q
self.res_q = res_q
self.next_queue = 0

def __call__(self, *args):

# note here: the RUN state of python3.7 and python3.8 is different:
# python3.7: RUN = 0
# python3.8: RUN = "RUN"
# so we use self.pool._state == RUN instead and we can't use _state == 0 any more.
if self.pool is not None and self.pool._state == RUN and check_iterator_cleanup() is False: # pylint: disable=W0212
# This call will send the tensors along with Python callable index to the process pool.
# Block, yield GIL. Current thread will reacquire GIL once result is returned.
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
# arg_q will have 0 size if we are not using shared memory
# if using multi-processing shared queue instead of multiprocess arg passing
if self.arg_q != []:
tid = threading.get_ident()
# Need to register each thread to use a different queue to send data to pool
if not tid in self.queuemap:
qid = self.next_queue
self.next_queue = self.next_queue + 1
self.queuemap[tid] = qid
else:
qid = self.queuemap[tid]
self.arg_q[qid].put(args)

# This call will send the tensors along with Python callable index to the process pool.
# Block, yield GIL. Current thread will reacquire GIL once result is returned.
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, qid, []])
else:
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, -1, *args])

# TODO: verify that check_iterator_cleanup() is the correct termination condition for this retry loop
while check_iterator_cleanup() is False:
try:
if self.arg_q != []:
r = result.get(30)
if r[0] != qid:
raise Exception("In PyCallable, got results from wrong thread")
r = self.res_q[qid].get()
return r
return result.get(30)
except multiprocessing.TimeoutError:
continue
@@ -2318,13 +2383,15 @@ class MapDataset(Dataset):
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
callbacks: (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None)
max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
data between processes. This is only used if python_multiprocessing is set to True (default 16 MB).

Raises:
ValueError: If len(input_columns) != len(output_columns) and column_order is not specified.
"""

def __init__(self, input_dataset, operations=None, input_columns=None, output_columns=None, column_order=None,
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None, max_rowsize=16):
super().__init__(children=input_dataset, num_parallel_workers=num_parallel_workers, cache=cache)
self.operations = to_list(operations)
self.operations = py_transforms.Compose.reduce(self.operations)
@@ -2347,6 +2414,7 @@ class MapDataset(Dataset):
self.hook = None

self.callbacks = to_list(callbacks)
self.max_rowsize = max_rowsize

def parse(self, children=None):
operations = []
@@ -2374,6 +2442,19 @@ class MapDataset(Dataset):
if self.python_multiprocessing:
iter_specific_operations = []
callable_list = []
arg_q_list = []
res_q_list = []

# If user didn't specify num_parallel_workers, set it to default
if self.num_parallel_workers is not None:
num_parallel = self.num_parallel_workers
else:
num_parallel = get_num_parallel_workers()

if get_enable_shared_mem():
for _ in range(num_parallel):
arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))

# Pass #1, look for Python callables and build list
for op in self.operations:
@@ -2384,8 +2465,9 @@ class MapDataset(Dataset):
if callable_list:
# Construct pool with the callable list
# The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses
self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
initializer=_pyfunc_worker_init, initargs=(callable_list,))
self.process_pool = multiprocessing.Pool(processes=num_parallel,
initializer=_pyfunc_worker_init,
initargs=(callable_list, arg_q_list, res_q_list))

# Pass #2
idx = 0
@@ -2401,7 +2483,8 @@ class MapDataset(Dataset):
# our c transforms is now callable and should not be run in Python multithreading
if callable(op) and str(op).find("c_transform") < 0:
# Wrap Python callable into _PythonCallable
iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool,
arg_q_list, res_q_list))
idx += 1
else:
# CPP ops remain the same
@@ -3186,7 +3269,7 @@ class SamplerFn:
Multiprocessing or multithread generator function wrapper master process.
"""

def __init__(self, dataset, num_worker, multi_process):
def __init__(self, dataset, num_worker, multi_process, max_rowsize):
self.workers = []
self.num_worker = num_worker
self.multi_process = multi_process
@@ -3199,9 +3282,16 @@ class SamplerFn:
else:
self.eof = threading.Event()
# Create workers

# Get the default queue size and shrink the per-worker queue size when there is a large number of workers
queue_size = get_prefetch_size()
queue_size = min(queue_size, queue_size * 4 // num_worker)
queue_size = max(2, queue_size)


for _ in range(num_worker):
if multi_process is True:
worker = _GeneratorWorkerMp(dataset, self.eof)
worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size)
worker.daemon = True
# When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
# which may cause deadlock. Therefore, the subprocess startup is performed in the initialization phase.
@@ -3363,9 +3453,12 @@ class _GeneratorWorkerMp(multiprocessing.Process):
Worker process for multiprocess Generator.
"""

def __init__(self, dataset, eof):
self.idx_queue = multiprocessing.Queue(16)
self.res_queue = multiprocessing.Queue(16)
def __init__(self, dataset, eof, max_rowsize, queue_size):
self.idx_queue = multiprocessing.Queue(queue_size)
if get_enable_shared_mem():
self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize)
else:
self.res_queue = multiprocessing.Queue(queue_size)
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True))

def put(self, item):
@@ -3453,6 +3546,8 @@ class GeneratorDataset(MappableDataset):
when num_shards is also specified. Random accessible input is required.
python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This
option could be beneficial if the Python operation is computational heavy (default=True).
max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
data between processes. This is only used if python_multiprocessing is set to True (default 6 MB).

Examples:
>>> import numpy as np
@@ -3516,7 +3611,7 @@ class GeneratorDataset(MappableDataset):
@check_generatordataset
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
python_multiprocessing=True):
python_multiprocessing=True, max_rowsize=6):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
self.source = source
@@ -3542,6 +3637,8 @@ class GeneratorDataset(MappableDataset):
if hasattr(self.source, "__len__"):
self.source_len = len(self.source)

self.max_rowsize = max_rowsize

def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
@@ -3550,7 +3647,8 @@ class GeneratorDataset(MappableDataset):
sample_fn = None
if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
if new_op.num_parallel_workers > 1:
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
self.max_rowsize)
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
else:
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))


+ 117
- 0
mindspore/dataset/engine/queue.py View File

@@ -0,0 +1,117 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This dataset module creates an internal queue class to more optimally pass data
between multiple processes in python. It has same API as multiprocessing.queue
but it will pass large data through shared memory.
"""

import multiprocessing.queues
import multiprocessing
import numpy as np
from mindspore import log as logger


class _SharedQueue(multiprocessing.queues.Queue):
"""
Class to implement a queue using shared memory for better performance.
Args:
size: Number of elements in the queue.
copy_out: Flag to indidcate whether an extra copy should be done before returning. If data will immediately be
copied before returning, then this can be set to False.
max_rowsize: Maximum size of any element in the Queue in MB.
"""
def __init__(self, size, copy_out=False, max_rowsize=6):
super().__init__(size, ctx=multiprocessing.get_context())

self.copy_out = copy_out

# change max_rowsize in MB into bytes
self.seg_size = max_rowsize * 1024 * 1024
##pipe can hold up to 65,636 bytes at a time
self.min_shared_mem = 10000
self.shm_list = []
self.seg_pos = 0
# num_seg has to be 2 more than the queue size. We can have remote worker filling a buffer, main process
# reading a buffer and also have a full queue of buffers in the meta-data queue
self.num_seg = size + 2
self.data_immediate = 0
self.data_shared = 1
self.print_error = True

try:
for _ in range(self.num_seg):
a = multiprocessing.Array('b', self.seg_size)
self.shm_list.append(a)
except:
raise "_SharedQueue: Error allocating " + str(self.seg_size) + "bytes, " + str(self.num_seg) + " elements."

def put(self, data, timeout=None):
name_list = []
count = 0
start_bytes = 0
for r in data:
if (isinstance(r, np.ndarray) and r.size > self.min_shared_mem and
start_bytes + r.nbytes < self.seg_size):
##need to convert start_bytes to offset in array
start_offset = start_bytes // r.dtype.itemsize
dest = np.ndarray(r.shape, r.dtype, buffer=self.shm_list[self.seg_pos].get_obj(), offset=start_offset)
np.copyto(dest, r)
byte = r.nbytes
byte = 8 * ((byte + 7) // 8)
start_bytes += byte
name_list.append((self.data_shared, self.seg_pos, byte, r.dtype, r.shape))
count += 1
else:
if isinstance(r, np.ndarray) and r.size >= self.min_shared_mem:
## Only print out error the first time it happens
if self.print_error:
logger.warning("Using shared memory queue, but rowsize is larger than allocated memory " +
"max_rowsize " + str(self.seg_size) + " current rowwize " +
str(start_bytes + r.nbytes))
self.print_error = False
name_list.append((self.data_immediate, r))

super().put(name_list, timeout=timeout)
## note above could generate a queue full exception. It will be handled by teh caller
## only increment seg_pos after successfully adding to metadata queue

if start_bytes > 0:
self.seg_pos = (self.seg_pos +1) % self.num_seg

def get(self, timeout=None):
result = super().get(timeout=timeout)
r = []
start_bytes = 0
for x in result:
if x[0] == self.data_shared:
seg_pos = x[1]
byte = x[2]
dtype = x[3]
shape = x[4]
start_offset = start_bytes // dtype.itemsize
b = self.shm_list[seg_pos]
data = np.ndarray(shape, dtype, buffer=b.get_obj(), offset=start_offset)
start_bytes += byte
if self.copy_out:
data2 = np.copy(data)
r.append(data2)
else:
r.append(data)
elif x[0] == self.data_immediate:
r.append(x[1])
else:
raise "SharedQueue, invalid entry in metadata."
return tuple(r)

+ 2
- 0
tests/ut/cpp/dataset/client_config_test.cc View File

@@ -53,12 +53,14 @@ TEST_F(MindDataTestClientConfig, TestClientConfig1) {
my_conf->set_worker_connector_size(3);
my_conf->set_op_connector_size(4);
my_conf->set_seed(5);
my_conf->set_enable_shared_mem(false);


ASSERT_EQ(my_conf->num_parallel_workers(), 2);
ASSERT_EQ(my_conf->worker_connector_size(), 3);
ASSERT_EQ(my_conf->op_connector_size(), 4);
ASSERT_EQ(my_conf->seed(), 5);
ASSERT_EQ(my_conf->enable_shared_mem(), false);

std::string file = datasets_root_path_ + "/declient.cfg";
ASSERT_TRUE(my_conf->LoadFile(file));


+ 27
- 0
tests/ut/python/dataset/test_datasets_generator.py View File

@@ -476,6 +476,32 @@ def test_generator_17():
np.testing.assert_array_equal(item["col1"], golden)
i = i + 1

def test_generator_18():
    """
    Test multiprocessing flag (same as test 13 with python_multiprocessing=True flag)
    """
    logger.info("Test map column order when input_columns is None.")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_mc(2048), ["col0", "col1"], python_multiprocessing=True)
    data1 = data1.map(operations=(lambda x: (x * 5)), output_columns=["out0"], num_parallel_workers=2,
                      python_multiprocessing=True)

    # Expected column order is |out0|col1|
    i = 0
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item[0], golden)
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        np.testing.assert_array_equal(item[1], golden)
        i = i + 1

    # Reset the row counter before the second pass: the previous loop left `i`
    # at the dataset length, which made the golden values below incorrect.
    i = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # len should be 2 because col0 is dropped (not included in column_order)
        assert len(item) == 2
        golden = np.array([i * 5])
        np.testing.assert_array_equal(item["out0"], golden)
        # advance the counter so every row is verified, not just the first
        i = i + 1

def test_generator_error_1():
def generator_np():
@@ -777,6 +803,7 @@ if __name__ == "__main__":
test_generator_15()
test_generator_16()
test_generator_17()
test_generator_18()
test_generator_error_1()
test_generator_error_2()
test_generator_error_3()


+ 22
- 0
tests/ut/python/dataset/test_pyfunc.py View File

@@ -250,6 +250,28 @@ def test_case_9():
i = i + 4


def test_case_10():
    """
    Test chained pyfunc maps with python multiprocessing enabled.
    """
    logger.info("Test multiple map with multiprocess: lambda x : x + x")

    # build the dataset pipeline
    dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    dataset = dataset.map(operations=[(lambda x: x * 10)], input_columns="col0",
                          output_columns="out", num_parallel_workers=4, python_multiprocessing=True)
    dataset = dataset.map(operations=[(lambda x: x + x), (lambda x: x + 1), (lambda x: x + 2)],
                          input_columns="out", output_columns="out", num_parallel_workers=4,
                          python_multiprocessing=True)

    # The source data is 2x2 sequential tensors, so each element x becomes
    # x -> x * 10 -> (x + x) + 1 + 2 = 20 * x + 3.
    base = 0
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        expected = np.array([[base * 20 + 3, (base + 1) * 20 + 3],
                             [(base + 2) * 20 + 3, (base + 3) * 20 + 3]])
        np.testing.assert_array_equal(row["out"], expected)
        base = base + 4


def test_pyfunc_implicit_compose():
"""
Test Implicit Compose with pyfunc


Loading…
Cancel
Save