
Add python multiprocessing support for Mindspore.dataset

tags/v0.2.0-alpha
Junhan Hu committed 5 years ago · commit b13e7bc31a
3 changed files with 204 additions and 3 deletions:
  1. mindspore/dataset/engine/datasets.py (+94 / -3)
  2. mindspore/dataset/engine/iterators.py (+4 / -0)
  3. tests/ut/python/dataset/test_pyfunc.py (+106 / -0)

mindspore/dataset/engine/datasets.py (+94 / -3)

@@ -24,6 +24,7 @@ import math
import os
import random
import uuid
+import multiprocessing
from enum import Enum
from importlib import import_module


@@ -231,7 +232,7 @@ class Dataset:


    @check_map
    def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
-           num_parallel_workers=None):
+           num_parallel_workers=None, python_multiprocessing=False):
        """
        Applies each operation in operations to this dataset.


@@ -270,6 +271,8 @@ class Dataset:
                same).
            num_parallel_workers (int, optional): Number of threads used to process the dataset in
                parallel (default=None, the value from the config will be used).
+           python_multiprocessing (bool, optional): Parallelize python operations with multiple worker
+               processes. This option can be beneficial if the python operation is computationally heavy
+               (default=False).

        Returns:
            MapDataset, dataset after mapping operation.
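For orientation, the new flag is simply passed through map(); a minimal usage sketch based on the tests added below (DATA_DIR and SCHEMA_DIR are the usual test fixtures):

data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
# With python_multiprocessing=True, the lambda runs in a pool of 4 worker
# processes instead of 4 threads, sidestepping the GIL for heavy python ops.
data1 = data1.map(input_columns="col0", output_columns="out",
                  operations=(lambda x: x + x),
                  num_parallel_workers=4, python_multiprocessing=True)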
@@ -383,7 +386,8 @@ class Dataset:
            >>> columns_order = ["mod7", "mod3", "col1"]
            >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
        """
-       return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers)
+       return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
+                         python_multiprocessing)

    @check_repeat
    def repeat(self, count=None):
@@ -1041,6 +1045,55 @@ class ShuffleDataset(DatasetOp):
        return args




# Pyfunc collection for multiprocess pyfunc
# This global variable will only be used within subprocesses
_GLOBAL_PYFUNC_LIST = []


# Pyfunc worker init function
# The python multiprocessing library forbids sending lambda functions through pipes.
# This init function allows us to add all python functions to a global collection and then fork afterwards.
def _pyfunc_worker_init(pyfunc_list):
    global _GLOBAL_PYFUNC_LIST
    _GLOBAL_PYFUNC_LIST = pyfunc_list


# Pyfunc worker execution function
# All exceptions will be raised to the main process
def _pyfunc_worker_exec(index, *args):
    try:
        return _GLOBAL_PYFUNC_LIST[index](*args)
    except KeyboardInterrupt:
        raise Exception("Multiprocess MapOp worker received KeyboardInterrupt")
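The same initializer/index pattern in isolation; a standalone sketch, not part of the commit, with illustrative names. It relies on the fork start method (the Linux default), since lambdas cannot be pickled under spawn:

import multiprocessing

_FUNCS = []


def _init(funcs):
    global _FUNCS
    _FUNCS = funcs


def _exec(index, arg):
    # Look the callable up by index instead of pickling it through a pipe.
    return _FUNCS[index](arg)


if __name__ == "__main__":
    pool = multiprocessing.Pool(processes=2, initializer=_init,
                                initargs=([lambda x: x + 1, lambda x: x * 2],))
    print(pool.apply(_exec, [0, 10]))  # 11
    print(pool.apply(_exec, [1, 10]))  # 20
    pool.close()
    pool.join()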


# PythonCallable wrapper for multiprocess pyfunc
class _PythonCallable:
    """
    Internal python function wrapper for multiprocessing pyfunc.
    """
    def __init__(self, py_callable, idx, pool=None):
        # Original python callable from user.
        self.py_callable = py_callable
        # Process pool created for current iterator.
        self.pool = pool
        # Python callable index for subprocess _GLOBAL_PYFUNC_LIST
        self.idx = idx

    def __call__(self, *args):
        if self.pool is not None:
            try:
                # This call sends the tensors along with the Python callable index to the process pool.
                # Block, yield GIL. The current thread will reacquire the GIL once the result is returned.
                return self.pool.apply(_pyfunc_worker_exec, [self.idx, *args])
            except KeyboardInterrupt:
                self.pool.terminate()
                self.pool.join()
                raise Exception("Multiprocess MapOp worker received KeyboardInterrupt")
        # Invoke the original python callable in the master process in case the pool is gone.
        return self.py_callable(*args)
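The blocking pool.apply is a deliberate design choice: each of the num_parallel_workers MapOp threads blocks on its own call, so the pool stays saturated without extra queueing logic. Per the multiprocessing docs, pool.apply is just the blocking form of apply_async, so the call above is equivalent to this sketch (pool and idx as in the class members):

# Same effect as self.pool.apply(...) above: dispatch to one worker
# process and wait (GIL released) until the result comes back.
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args]).get()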


class MapDataset(DatasetOp):
    """
    The result of applying Map operator to the input Dataset.
@@ -1060,13 +1113,15 @@ class MapDataset(DatasetOp):
        The argument is mandatory if len(input_columns) != len(output_columns).
        num_parallel_workers (int, optional): Number of workers to process the Dataset
            in parallel (default=None).
+       python_multiprocessing (bool, optional): Parallelize python operations with multiple worker
+           processes. This option can be beneficial if the python operation is computationally heavy
+           (default=False).


    Raises:
        ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
    """

    def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
-                num_parallel_workers=None):
+                num_parallel_workers=None, python_multiprocessing=False):
        super().__init__(num_parallel_workers)
        self.input.append(input_dataset)
        if input_columns is not None and not isinstance(input_columns, list):
@@ -1087,6 +1142,8 @@ class MapDataset(DatasetOp):


        input_dataset.output.append(self)
        self._input_indexs = input_dataset.input_indexs
+       self.python_multiprocessing = python_multiprocessing
+       self.process_pool = None

    def get_args(self):
        args = super().get_args()
@@ -1104,6 +1161,40 @@ class MapDataset(DatasetOp):
""" """
return self.input[0].get_dataset_size() return self.input[0].get_dataset_size()


    # Iterator bootstrap will be called on iterator construction.
    # A deep copy of the Dataset object is created prior to iterator_bootstrap.
    # This method creates a per-iterator process pool and binds pyfunc execution to that pool.
    def iterator_bootstrap(self):
        """
        Per iterator bootstrap callback.
        """
        if self.python_multiprocessing:
            iter_specific_operations = []
            callable_list = []

            # Pass #1, look for python callables and build the list
            for op in self.operations:
                if callable(op):
                    callable_list.append(op)

            if callable_list:
                # Construct the pool with the callable list.
                # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses.
                self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
                                                         initializer=_pyfunc_worker_init,
                                                         initargs=(callable_list,))
                # Pass #2, replace each python callable with a pool-backed wrapper
                idx = 0
                for op in self.operations:
                    if callable(op):
                        # Wrap python callable into _PythonCallable
                        iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
                        idx += 1
                    else:
                        # CPP ops remain the same
                        iter_specific_operations.append(op)
                self.operations = iter_specific_operations
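The two passes keep the indices aligned: pass #1 collects the python callables in pipeline order, so the i-th entry of the subprocess-side _GLOBAL_PYFUNC_LIST matches the i-th _PythonCallable built in pass #2. A sketch of the transformation (op names illustrative):

# Before bootstrap:
#   self.operations = [c_resize_op, lambda x: x + x, lambda x: x + 1]
# After bootstrap, both lambdas run in the pool workers, addressed by index:
#   self.operations = [c_resize_op,
#                      _PythonCallable(<lambda>, 0, pool),
#                      _PythonCallable(<lambda>, 1, pool)]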



class RepeatDataset(DatasetOp):
    """


mindspore/dataset/engine/iterators.py (+4 / -0)

@@ -63,6 +63,10 @@ def _alter_node(node):
        return new_shuffle

    if isinstance(node, de.MapDataset):
+       if node.python_multiprocessing:
+           # Bootstrap can only be performed on a copy of the original dataset node.
+           # Bootstrapping the original node would make all iterators share the same process pool.
+           node.iterator_bootstrap()
        if node.columns_order is not None:
            # Remove the connection between the parent node and the current node because we are inserting a node.
            if node.output:
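Since _alter_node runs against the deep copy made during iterator construction, every iterator bootstraps its own copy of the MapDataset node; a sketch of the observable effect, using the public iterator API:

data1 = data1.map(input_columns="col0", operations=(lambda x: x + x),
                  num_parallel_workers=4, python_multiprocessing=True)
it1 = data1.create_dict_iterator()  # bootstraps copy #1 -> its own process pool
it2 = data1.create_dict_iterator()  # bootstraps copy #2 -> an independent pool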


tests/ut/python/dataset/test_pyfunc.py (+106 / -0)

@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
+import pytest

import mindspore.dataset as ds
from mindspore import log as logger
@@ -181,6 +182,106 @@ def test_case_6():
        i = i + 4




def test_case_7():
    """
    Test 1-1 PyFunc with python_multiprocessing: lambda x: x + x
    """
    logger.info("Test 1-1 PyFunc Multiprocess: lambda x : x + x")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)

    data1 = data1.map(input_columns="col0", output_columns="out", operations=(lambda x: x + x),
                      num_parallel_workers=4, python_multiprocessing=True)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # In this test, the dataset is 2x2 sequential tensors
        golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
        assert np.array_equal(item["out"], golden)
        i = i + 4


def test_case_8():
    """
    Test n-m PyFunc with python_multiprocessing: lambda x, y: (x, x + y, x + y + 1)
    """
    logger.info("Test Multiprocess n-m PyFunc : lambda x, y : (x, x + y, x + y + 1)")

    col = ["col0", "col1"]

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)

    data1 = data1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4,
                      operations=(lambda x, y: (x, x + y, x + y + 1)), columns_order=["out0", "out1", "out2"],
                      python_multiprocessing=True)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # In this test, the dataset is 2x2 sequential tensors
        golden = np.array([[i, i + 1], [i + 2, i + 3]])
        assert np.array_equal(item["out0"], golden)
        golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]])
        assert np.array_equal(item["out1"], golden)
        golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]])
        assert np.array_equal(item["out2"], golden)
        i = i + 4


def test_case_9():
    """
    Test multiple chained 1-1 PyFuncs with python_multiprocessing: x + x, x + 1, x + 2
    """
    logger.info("Test multiple 1-1 PyFunc Multiprocess: lambda x : x + x")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)

    data1 = data1.map(input_columns="col0", output_columns="out",
                      operations=[(lambda x: x + x), (lambda x: x + 1), (lambda x: x + 2)],
                      num_parallel_workers=4, python_multiprocessing=True)

    i = 0
    for item in data1.create_dict_iterator():  # each data is a dictionary
        # In this test, the dataset is 2x2 sequential tensors
        golden = np.array([[i * 2 + 3, (i + 1) * 2 + 3], [(i + 2) * 2 + 3, (i + 3) * 2 + 3]])
        assert np.array_equal(item["out"], golden)
        i = i + 4


def test_pyfunc_exception():
    logger.info("Test PyFunc Exception Throw: pyfunc raises Exception()")

    def pyfunc(x):
        raise Exception("Pyfunc Throw")

    with pytest.raises(RuntimeError) as info:
        # apply dataset operations
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        data1 = data1.map(input_columns="col0", output_columns="out", operations=pyfunc,
                          num_parallel_workers=4)
        for _ in data1:
            pass
    assert "Pyfunc Throw" in str(info.value)


def test_pyfunc_exception_multiprocess():
    logger.info("Test Multiprocess PyFunc Exception Throw: pyfunc raises Exception()")

    def pyfunc(x):
        raise Exception("MP Pyfunc Throw")

    with pytest.raises(RuntimeError) as info:
        # apply dataset operations
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
        data1 = data1.map(input_columns="col0", output_columns="out", operations=pyfunc,
                          num_parallel_workers=4, python_multiprocessing=True)
        for _ in data1:
            pass
    assert "MP Pyfunc Throw" in str(info.value)


if __name__ == "__main__":
    test_case_0()
    test_case_1()
@@ -189,3 +290,8 @@ if __name__ == "__main__":
    test_case_4()
    test_case_5()
    test_case_6()
    test_case_7()
    test_case_8()
    test_case_9()
    test_pyfunc_exception()
    test_pyfunc_exception_multiprocess()
